# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for replicate kernel.
class ReplicateOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        # Save the bin boundaries for the backward pass, then replicate
        # each entry of `x` across its bin in the preallocated output.
        ctx.save_for_backward(bins)
        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
        ops.replicate_forward(x, bins, out)
        return out
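
    # The backward pass applies the adjoint of replication: the incoming
    # gradient is reduced within each bin, yielding one column per bin
    # (hence the (grad.shape[0], bins.shape[0]) output shape below).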
    @staticmethod
    def backward(ctx: Any, grad: torch.Tensor):
        bins, = ctx.saved_tensors
        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
        ops.replicate_backward(grad, bins, out)
        # No gradients for the `bins` and `num_outputs` arguments.
        return out, None, None


replicate = ReplicateOp.apply
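
# A minimal usage sketch, not part of the library: it assumes `bins` holds
# cumulative end offsets per bin (as produced by an inclusive cumsum over
# bin sizes), that `num_outputs` equals the last offset, and that the
# compiled megablocks_ops kernels operate on CUDA tensors; the exact dtype
# and integer-width expectations depend on how the extension was built.
if __name__ == '__main__':
    if torch.cuda.is_available():
        # One value per bin; each is broadcast across its bin's span.
        x = torch.randn(1, 3, device='cuda')
        # Bins covering output positions [0, 2), [2, 5), and [5, 8).
        bins = torch.tensor([2, 5, 8], dtype=torch.int32, device='cuda')
        out = replicate(x, bins, 8)
        print(out.shape)  # expected: torch.Size([1, 8])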