drbh committed on
Commit
f9791fd
·
1 Parent(s): c46b1d4

feat: push full template and build to repo

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ build
README.md ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # first-kernel
2
+
3
+ A custom kernel for PyTorch.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install drbh/first-kernel
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import torch
15
+ from first_kernel import first_kernel
16
+
17
+ # Create input tensor
18
+ x = torch.randn(1024, 1024, device="cuda")
19
+
20
+ # Run kernel
21
+ result = first_kernel(x)
22
+ ```
23
+
24
+ ## Development
25
+
26
+ ### Building
27
+
28
+ ```bash
29
+ nix develop
30
+ nix run .#build-and-copy
31
+ ```
32
+
33
+ ### Testing
34
+
35
+ ```bash
36
+ nix develop .#test
37
+ pytest tests/
38
+ ```
39
+
40
+ ### Test as a `kernels` user
41
+
42
+ ```bash
43
+ uv run example.py
44
+ ```
45
+
46
+ ## License
47
+
48
+ Apache 2.0
build.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
[general]
# Only the Metal (macOS GPU) backend is built for this kernel.
backends = ["metal"]
name = "first-kernel"
version = 1

# Metal implementation: Objective-C++ host code plus the shader source that
# is compiled into an embedded metallib.
[kernel.first_kernel_metal]
backend = "metal"
depends = ["torch"]
src = [
    "first_kernel_metal/first_kernel.mm",
    "first_kernel_metal/first_kernel.metal",
]

# Torch extension glue: registers the op with the PyTorch dispatcher.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]
example.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "kernels",
# "numpy",
# "torch",
# ]
# ///

"""Smoke-test the locally built first_kernel: it must add 1.0 to each element."""

import platform
from pathlib import Path

import kernels
import torch

# Load the locally built kernel from the ./build directory.
kernel = kernels.get_local_kernel(Path("build"), "first_kernel")

# Select device. On macOS, confirm the MPS backend is actually usable before
# selecting it (mirrors the availability checks on the xpu/cuda branches);
# otherwise fall through to the next option.
if platform.system() == "Darwin" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
elif torch.version.cuda is not None and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Create input tensor
x = torch.tensor([1.0, 2.0, 3.0], device=device)
print(f"Input: {x}")

# Run kernel (adds 1 to each element)
result = kernel.first_kernel(x)
print(f"Output: {result}")

# Verify result
expected = x + 1.0
assert torch.allclose(result, expected), "Kernel output doesn't match expected!"
print("Success!")
first_kernel_metal/first_kernel.metal ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
#include <metal_stdlib>
using namespace metal;

// Element-wise "add one": output[i] = input[i] + 1.0f.
// One thread handles exactly one element; the host side dispatches exactly
// one thread per tensor element via dispatchThreads, so no bounds check is
// performed here.
kernel void first_kernel_kernel(device const float *input [[buffer(0)]],
                                device float *output [[buffer(1)]],
                                uint tid [[thread_position_in_grid]]) {
  output[tid] = input[tid] + 1.0f;
}
first_kernel_metal/first_kernel.mm ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// The build system compiles first_kernel.metal into a metallib, embeds its
// bytes in a generated header, and points EMBEDDED_METALLIB_HEADER at it;
// EMBEDDED_METALLIB_NAMESPACE::createLibrary reconstitutes an MTLLibrary
// from those bytes at runtime.
#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// Reinterpret the MPS tensor's raw storage pointer as the id<MTLBuffer>
// that backs it. NOTE(review): relies on PyTorch's MPS allocator storing an
// MTLBuffer as the storage data pointer — confirm against the torch version
// in use.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

/// Computes out[i] = input[i] + 1.0f element-wise on the MPS device.
///
/// Requires contiguous float32 MPS tensors of identical shape/dtype/device.
/// Work is encoded on PyTorch's MPS command stream so it is ordered with
/// respect to other MPS operations.
void first_kernel(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float,
              "first_kernel only supports float32");
  TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have same dtype");
  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on same device");

  // Empty tensors: nothing to do, and dispatching a zero-sized grid (with a
  // zero-sized threadgroup) is invalid in Metal. int64_t avoids overflowing
  // a 32-bit int for very large tensors.
  const int64_t numThreads = input.numel();
  if (numThreads == 0) {
    return;
  }

  @autoreleasepool {
    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
    TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");

    // Build the pipeline on the device that backs PyTorch's MPS command
    // buffer rather than MTLCreateSystemDefaultDevice(): the pipeline state
    // must be created on the same device the command buffer belongs to.
    id<MTLDevice> device = cmdBuf.device;

    NSError *error = nil;
    id<MTLLibrary> library =
        EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
    TORCH_CHECK(library, "Failed to create Metal library: ",
                error.localizedDescription.UTF8String);

    id<MTLFunction> func =
        [library newFunctionWithName:@"first_kernel_kernel"];
    TORCH_CHECK(func, "Failed to create function");

    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:func error:&error];
    TORCH_CHECK(pso, error.localizedDescription.UTF8String);

    // Encode on PyTorch's MPS dispatch queue so we do not race other work
    // encoding into the same command buffer.
    dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
      id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
      [encoder setComputePipelineState:pso];
      // Byte offsets account for tensors that view into a larger storage.
      [encoder setBuffer:getMTLBufferStorage(input)
                  offset:input.storage_offset() * input.element_size()
                 atIndex:0];
      [encoder setBuffer:getMTLBufferStorage(out)
                  offset:out.storage_offset() * out.element_size()
                 atIndex:1];

      // One thread per element; dispatchThreads handles the non-uniform
      // final threadgroup, so no padding or shader-side bounds check needed.
      NSUInteger tgSize =
          MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads);
      [encoder dispatchThreads:MTLSizeMake((NSUInteger)numThreads, 1, 1)
         threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
      [encoder endEncoding];
      torch::mps::commit();
    });
  }
}
flake.lock ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1765121682,
6
+ "narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rust-overlay": "rust-overlay"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1774018498,
45
+ "narHash": "sha256-enigJmSw6g6e7PjsQ9z8aaMJJaSUVEOpOHsKulWhaSs=",
46
+ "owner": "huggingface",
47
+ "repo": "kernels",
48
+ "rev": "efe2480951107f1880a59cf1b5ae364b5d861566",
49
+ "type": "github"
50
+ },
51
+ "original": {
52
+ "owner": "huggingface",
53
+ "repo": "kernels",
54
+ "type": "github"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1766341660,
60
+ "narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
61
+ "owner": "NixOS",
62
+ "repo": "nixpkgs",
63
+ "rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "NixOS",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "root": {
74
+ "inputs": {
75
+ "kernel-builder": "kernel-builder"
76
+ }
77
+ },
78
+ "rust-overlay": {
79
+ "inputs": {
80
+ "nixpkgs": [
81
+ "kernel-builder",
82
+ "nixpkgs"
83
+ ]
84
+ },
85
+ "locked": {
86
+ "lastModified": 1769050281,
87
+ "narHash": "sha256-1H8DN4UZgEUqPUA5ecHOufLZMscJ4IlcGaEftaPtpBY=",
88
+ "owner": "oxalica",
89
+ "repo": "rust-overlay",
90
+ "rev": "6deef0585c52d9e70f96b6121207e1496d4b0c49",
91
+ "type": "github"
92
+ },
93
+ "original": {
94
+ "owner": "oxalica",
95
+ "repo": "rust-overlay",
96
+ "type": "github"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
{
  # kernel-builder supplies the Nix machinery for building Hugging Face
  # "kernels" projects (dev shells, build targets, etc.).
  inputs = {
    kernel-builder.url = "github:huggingface/kernels";
  };
  outputs =
    { self, kernel-builder, ... }:
    # Generate the standard kernel flake outputs for the project rooted at
    # this directory (reads build.toml for the kernel definition).
    kernel-builder.lib.genKernelFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
tests/__init__.py ADDED
File without changes
tests/test_first_kernel.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import platform

import torch

import first_kernel


def _pick_device() -> torch.device:
    """Return the preferred accelerator for this host, falling back to CPU."""
    if platform.system() == "Darwin":
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if torch.version.cuda is not None and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def test_first_kernel():
    """first_kernel must add exactly 1.0 to every element of a float32 tensor."""
    device = _pick_device()
    x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
    expected = x + 1.0
    result = first_kernel.first_kernel(x)
    torch.testing.assert_close(result, expected)
torch-ext/first_kernel/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import Optional

import torch

from ._ops import ops


def first_kernel(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the first_kernel op, writing the result into ``out``.

    Args:
        x: Input tensor.
        out: Optional pre-allocated output tensor; when omitted, a new tensor
            matching ``x`` (same shape, dtype, device) is allocated.

    Returns:
        The output tensor — ``out`` if supplied, otherwise the new allocation.
    """
    result = out if out is not None else torch.empty_like(x)
    ops.first_kernel(result, x)
    return result
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Register this extension's ops with the PyTorch dispatcher under the
// extension's namespace (TORCH_EXTENSION_NAME is injected by the build).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema: mutates `out` in place (Tensor!), reads `input`, returns nothing.
  ops.def("first_kernel(Tensor! out, Tensor input) -> ()");
#if defined(METAL_KERNEL)
  // Only the Metal build ships an implementation, dispatched for MPS tensors.
  ops.impl("first_kernel", torch::kMPS, first_kernel);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
#pragma once

#include <torch/torch.h>

// Element-wise add-one kernel: writes input + 1.0f into `out`.
// Both tensors must be contiguous float32 MPS tensors with identical shape,
// dtype, and device (checked by the implementation).
void first_kernel(torch::Tensor &out, torch::Tensor const &input);