danieldk HF Staff commited on
Commit
9c46652
·
1 Parent(s): cdb83b9

Add build files

Browse files
build.toml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "python-invalid-dep"
3
+ license = "apache-2.0"
4
+ backends = [
5
+ "cpu",
6
+ "cuda",
7
+ "metal",
8
+ "rocm",
9
+ "xpu",
10
+ ]
11
+ version = 1
12
+
13
+ [general.hub]
14
+ repo-id = "kernels-tests/python-invalid-dep"
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for versions test kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernels";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genKernelFlakeOutputs {
14
+ inherit self;
15
+ path = ./.;
16
+ };
17
+ }
torch-ext/python_invalid_dep/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ._ops import ops
4
+ from .op import _silu_and_mul
5
+
6
+
7
+ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
8
+ return ops.silu_and_mul(x)
9
+
10
+
11
+ __all__ = ["silu_and_mul"]
torch-ext/python_invalid_dep/op.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ._ops import add_op_namespace_prefix
5
+
6
+
7
+ @torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
8
+ def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
9
+ d = x.shape[-1] // 2
10
+ return F.silu(x[..., :d]) * x[..., d:]
11
+
12
+
13
+ def backward(ctx, grad_output):
14
+ x = ctx.saved_tensors[0]
15
+ d = x.shape[-1] // 2
16
+ x1, x2 = x[..., :d], x[..., d:]
17
+ sigmoid_x1 = torch.sigmoid(x1)
18
+ silu_x1 = F.silu(x1)
19
+ dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1)
20
+ dx1 = grad_output * x2 * dsilu_dx1
21
+ dx2 = grad_output * silu_x1
22
+ return torch.cat([dx1, dx2], dim=-1)
23
+
24
+
25
+ def setup_context(ctx, inputs, output):
26
+ (x,) = inputs
27
+ ctx.save_for_backward(x)
28
+
29
+
30
+ _silu_and_mul.register_autograd(backward, setup_context=setup_context)
31
+
32
+
33
+ @_silu_and_mul.register_fake
34
+ def _(x: torch.Tensor) -> torch.Tensor:
35
+ return x.new_empty(x.shape[0], x.shape[1] // 2)