leonardlin's picture
Add ROCm build artifacts and HIP backend
1e407f0
raw
history blame
1.15 kB
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch
# Import the compiled extension operations. A missing module is re-raised
# as ModuleNotFoundError with a clearer, actionable message; the original
# exception is chained via ``from e`` so the real missing module still
# shows up in the traceback, and its ``name`` attribute is forwarded so
# code inspecting ``exc.name`` keeps working.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "No module named 'megablocks_ops'. "
        f"The compiled extension could not be imported ({e}); make sure "
        "the package's C++ operations were built for this platform.",
        name=e.name,
    ) from e
# Autograd wrapper around the ``ops.indices`` topology kernel.
# NOTE: Does not support gradients (no ``backward`` is defined).
class TopologyOp(torch.autograd.Function):
    """Run the ``indices`` extension kernel as a ``torch.autograd.Function``.

    A flat ``torch.int16`` buffer of
    ``output_block_rows * output_block_columns`` elements is allocated on
    the same device as ``padded_bins``. It is handed to the kernel and then
    returned.
    """

    @staticmethod
    def forward(
        ctx: Any,
        padded_bins: torch.Tensor,
        block_size: int,
        output_block_rows: int,
        output_block_columns: int,
    ):
        """Allocate the output buffer and let ``ops.indices`` populate it.

        Args:
            ctx: Autograd context. Not used; nothing is saved for backward.
            padded_bins: Tensor forwarded to the kernel. Its device decides
                where the output is allocated. NOTE(review): the expected
                dtype/shape is not visible here; confirm against the kernel.
            block_size: Forwarded to the kernel unchanged.
            output_block_rows: Multiplied with ``output_block_columns`` to
                size the output; also forwarded to the kernel.
            output_block_columns: Multiplied with ``output_block_rows`` to
                size the output; also forwarded to the kernel.

        Returns:
            A 1-D ``torch.int16`` tensor with
            ``output_block_rows * output_block_columns`` elements that was
            passed to ``ops.indices`` as its last argument.
        """
        num_elements = output_block_rows * output_block_columns
        result = torch.empty(
            num_elements,
            dtype=torch.int16,
            device=padded_bins.device,
        )
        # ``result`` is handed to the kernel as its output buffer. It is
        # presumably filled in place, since it is returned without any other
        # initialization; the call's own return value is ignored.
        ops.indices(
            padded_bins,
            block_size,
            output_block_rows,
            output_block_columns,
            result,
        )
        return result


# Functional entry point:
# ``topology(padded_bins, block_size, output_block_rows, output_block_columns)``.
topology = TopologyOp.apply