# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch
# NOTE: Import the compiled extension inside a try-block so that a
# missing build of the C++ operations produces a clear error message.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for topology kernel.
# NOTE: Does not support gradients.
class TopologyOp(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx: Any,
        padded_bins: torch.Tensor,
        block_size: int,
        output_block_rows: int,
        output_block_columns: int,
    ):
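        # Allocate a flat int16 output buffer with one entry per block in
        # the (output_block_rows x output_block_columns) topology grid.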
        out = torch.empty(
            output_block_rows * output_block_columns,
            dtype=torch.int16,
            device=padded_bins.device,
        )
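        # The custom kernel fills `out` in place with the topology indices
        # derived from the padded bin boundaries.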
        ops.indices(
            padded_bins,
            block_size,
            output_block_rows,
            output_block_columns,
            out,
        )
        return out


topology = TopologyOp.apply
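
# Usage sketch (illustrative only; the tensor values and sizes below are
# hypothetical). `padded_bins` holds cumulative, block-padded token counts,
# and the call returns the flat int16 topology buffer:
#
#     padded_bins = torch.tensor(
#         [128, 256, 384, 512], dtype=torch.int32, device='cuda',
#     )
#     block_topology = topology(padded_bins, 128, 4, 4)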