import torch
import torch.nn.functional as F

from ._ops import add_op_namespace_prefix

# register_autograd below requires a torch.library custom op; the op name
# "silu_and_mul" is assumed and should match the namespace helpers in ._ops.
@torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style gating: SiLU of the first half of the last dim,
    # multiplied elementwise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def backward(ctx, grad_output):
    x = ctx.saved_tensors[0]
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # d/dx silu(x) = sigmoid(x) + silu(x) * (1 - sigmoid(x))
    sigmoid_x1 = torch.sigmoid(x1)
    silu_x1 = F.silu(x1)
    dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1)
    dx1 = grad_output * x2 * dsilu_dx1
    dx2 = grad_output * silu_x1
    # Pack the two gradient halves back into the input's layout.
    return torch.cat([dx1, dx2], dim=-1)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)


# Wire the hand-written backward into autograd for the custom op.
_silu_and_mul.register_autograd(backward, setup_context=setup_context)

# Fake (meta) kernel: lets torch.compile and meta-device tracing infer the
# output shape (same leading dims, last dim halved, matching the forward)
# without running the real implementation.
@_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return x.new_empty(*x.shape[:-1], x.shape[-1] // 2)
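

# A minimal sanity-check sketch, assuming double-precision inputs:
# torch.autograd.gradcheck compares the hand-written backward above against
# numerical gradients. The helper name is illustrative, not part of the
# original module; call it from a context where the package is importable.
def _gradcheck_silu_and_mul() -> bool:
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(_silu_and_mul, (x,))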