leonardlin's picture
Add ROCm build artifacts and HIP backend
1e407f0
raw
history blame
1.15 kB
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch
# Wrap the extension import in a try-block so a missing build surfaces as a
# clearer error naming the expected compiled C++/HIP module.
try:
from .._ops import ops # type: ignore
except ModuleNotFoundError as e:
raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
# Autograd wrapper for replicate kernel.
class ReplicateOp(torch.autograd.Function):
    """Autograd bridge to the compiled ``replicate_forward``/``replicate_backward`` ops.

    Both directions allocate an uninitialized output buffer and hand it to the
    extension op, which is expected to fill it. Only ``x`` receives a gradient.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        """Run the replicate kernel on ``x`` using ``bins``.

        Args:
            ctx: autograd context; ``bins`` is stashed on it for ``backward``.
            x: input tensor. Presumably 2-D with ``x.shape[1] == bins.shape[0]``,
                which is implied by the gradient shape built in ``backward``.
                TODO confirm.
            bins: tensor consumed by the op (its exact encoding is defined by
                the extension, not visible here).
            num_outputs: size of dimension 1 of the returned tensor.

        Returns:
            A tensor of shape ``(x.shape[0], num_outputs)`` with ``x``'s dtype
            and device, written by ``ops.replicate_forward``.
        """
        # bins is the only tensor backward needs: it fixes the width of grad_x.
        ctx.save_for_backward(bins)
        result_shape = (x.shape[0], num_outputs)
        result = torch.empty(result_shape, dtype=x.dtype, device=x.device)
        # torch.empty leaves the buffer uninitialized; the op is expected to
        # write every element of it.
        ops.replicate_forward(x, bins, result)
        return result

    @staticmethod
    def backward(ctx: Any, grad: torch.Tensor):
        """Map the upstream gradient back onto ``x`` via ``ops.replicate_backward``.

        Returns:
            ``(grad_x, None, None)``, where ``grad_x`` has shape
            ``(grad.shape[0], bins.shape[0])``. ``bins`` and ``num_outputs``
            receive no gradient.
        """
        (bins,) = ctx.saved_tensors
        grad_x = torch.empty(
            (grad.shape[0], bins.shape[0]),
            dtype=grad.dtype,
            device=grad.device,
        )
        # NOTE(review): grad is passed through as-is. If the compiled kernel
        # assumes contiguous memory, a non-contiguous upstream gradient may
        # need .contiguous() first. Confirm against the kernel source.
        ops.replicate_backward(grad, bins, grad_x)
        return grad_x, None, None


# Functional entry point: replicate(x, bins, num_outputs).
replicate = ReplicateOp.apply