reach-vb (HF staff) committed
Commit 42a2b88
Parent: c125e70

18256559666a70d7f16bb789b379220578acb9d0f204be811e183cc0a99d467b

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h +16 -0
  3. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h +70 -0
  4. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h +481 -0
  5. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h +538 -0
  6. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h +176 -0
  7. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h +246 -0
  8. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h +31 -0
  9. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h +31 -0
  10. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h +370 -0
  11. lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h +169 -0
  12. lib/python3.11/site-packages/mlx/include/mlx/device.h +29 -0
  13. lib/python3.11/site-packages/mlx/include/mlx/dtype.h +105 -0
  14. lib/python3.11/site-packages/mlx/include/mlx/fft.h +151 -0
  15. lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h +23 -0
  16. lib/python3.11/site-packages/mlx/include/mlx/io/load.h +114 -0
  17. lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h +32 -0
  18. lib/python3.11/site-packages/mlx/include/mlx/linalg.h +63 -0
  19. lib/python3.11/site-packages/mlx/include/mlx/mlx.h +14 -0
  20. lib/python3.11/site-packages/mlx/include/mlx/ops.h +1094 -0
  21. lib/python3.11/site-packages/mlx/include/mlx/primitives.h +1636 -0
  22. lib/python3.11/site-packages/mlx/include/mlx/random.h +193 -0
  23. lib/python3.11/site-packages/mlx/include/mlx/scheduler.h +173 -0
  24. lib/python3.11/site-packages/mlx/include/mlx/stream.h +32 -0
  25. lib/python3.11/site-packages/mlx/include/mlx/transforms.h +187 -0
  26. lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h +17 -0
  27. lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h +187 -0
  28. lib/python3.11/site-packages/mlx/include/mlx/types/complex.h +77 -0
  29. lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h +234 -0
  30. lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h +56 -0
  31. lib/python3.11/site-packages/mlx/include/mlx/utils.h +44 -0
  32. lib/python3.11/site-packages/mlx/lib/libmlx.dylib +3 -0
  33. lib/python3.11/site-packages/mlx/lib/mlx.metallib +3 -0
  34. lib/python3.11/site-packages/mlx/nn/__init__.py +5 -0
  35. lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  36. lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc +0 -0
  37. lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc +0 -0
  38. lib/python3.11/site-packages/mlx/nn/layers/__init__.py +63 -0
  39. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc +0 -0
  40. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc +0 -0
  41. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc +0 -0
  42. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc +0 -0
  43. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc +0 -0
  44. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc +0 -0
  45. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc +0 -0
  46. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc +0 -0
  47. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc +0 -0
  48. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc +0 -0
  49. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc +0 -0
  50. lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/llvmlite/binding/libllvmlite.dylib filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/libmlx.dylib filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/mlx.metallib filter=lfs diff=lfs merge=lfs -text
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h ADDED
@@ -0,0 +1,16 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#ifdef __METAL__
+#define MTL_CONST constant
+#else
+#define MTL_CONST
+#endif
+
+static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
+static MTL_CONST constexpr int REDUCE_N_READS = 16;
+static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
+static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h ADDED
@@ -0,0 +1,70 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+
+/*
+ * Approximation to the error function.
+ * Based on code from:
+ * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
+ */
+float erf(float a) {
+  float r, s, t, u;
+  t = metal::abs(a);
+  s = a * a;
+  if (t > 0.927734375f) {
+    // maximum error 0.99527 ulp
+    r = metal::fma(
+        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
+    u = metal::fma(
+        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+    r = metal::fma(r, s, u);
+    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+    r = metal::fma(r, t, -t);
+    // TODO, replace with expm1 when implemented
+    r = 1.0f - metal::exp(r);
+    r = metal::copysign(r, a);
+  } else {
+    // maximum error 0.98929 ulp
+    r = -5.96761703e-4f; // -0x1.38e000p-11
+    r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+    r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
+    r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+    r = metal::fma(r, a, a);
+  }
+  return r;
+}
+
+float erfinv(float a) {
+  auto t = metal::fma(a, 0.0f - a, 1.0f);
+  t = metal::log(t);
+  float p;
+  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
+    p = 3.03697567e-10f; // 0x1.4deb44p-32
+    p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
+    p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
+    p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
+    p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
+    p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
+    p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
+    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
+    p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
+  } else { // maximum ulp error = 2.35002
+    p = 5.43877832e-9f; // 0x1.75c000p-28
+    p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
+    p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
+    p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
+    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
+    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
+    p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
+    p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
+    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
+    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
+  }
+  return a * p;
+}
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h ADDED
@@ -0,0 +1,481 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
+
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/conv_params.h"
+
+#define MLX_MTL_CONST static constant constexpr const
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int vec_size,
+    int tgp_size,
+    int tgp_padding = 0>
+struct Conv2DInputBlockLoader {
+  // Destination dimensions
+  MLX_MTL_CONST int dst_fd = BM;
+  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
+  MLX_MTL_CONST int n_vecs = BK / vec_size;
+
+  // Stride along block row within the block
+  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
+  MLX_MTL_CONST int n_rows = dst_fd / bstride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  const constant MLXConvParams<2>& params;
+
+  int weight_h;
+  int weight_w;
+
+  int offsets_n[n_rows];
+  int offsets_oh[n_rows];
+  int offsets_ow[n_rows];
+
+  /* Constructor */
+  METAL_FUNC Conv2DInputBlockLoader(
+      const device T* src_,
+      threadgroup T* dst_,
+      const constant MLXConvParams<2>& params_,
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / n_vecs),
+        bj(vec_size * (thread_idx % n_vecs)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bj),
+        params(params_),
+        weight_h(0),
+        weight_w(0) {
+    int out_n_pixels = params.oS[0] * params.oS[1];
+
+    for (int i = 0; i < n_rows; ++i) {
+      int offset_nhw = tid.y * BM + bi + i * bstride;
+      offsets_n[i] = offset_nhw / out_n_pixels;
+      int hw = offset_nhw % out_n_pixels;
+      offsets_oh[i] = hw / params.oS[1];
+      offsets_ow[i] = hw % params.oS[1];
+    }
+
+    (void)lid;
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+#pragma clang loop unroll(full)
+    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
+      int n = offsets_n[i];
+      int oh = offsets_oh[i];
+      int ow = offsets_ow[i];
+
+      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
+      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
+
+      // Read from input if in bounds
+      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
+        const device T* curr_src = src + n * params.in_strides[0] +
+            ih * params.in_strides[1] + iw * params.in_strides[2];
+
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = curr_src[j];
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    if (++weight_w < params.wS[1]) {
+      return;
+    }
+
+    weight_w = 0;
+
+    if (++weight_h < params.wS[0]) {
+      return;
+    }
+
+    weight_h = 0;
+
+    src += BK;
+  }
+};
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int vec_size,
+    int tgp_size,
+    int tgp_padding = 0>
+struct Conv2DWeightBlockLoader {
+  // Destination dimensions
+  MLX_MTL_CONST int dst_fd = BN;
+  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
+  MLX_MTL_CONST int n_vecs = BK / vec_size;
+
+  // Stride along block row within the block
+  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
+  MLX_MTL_CONST int n_rows = dst_fd / bstride;
+
+  // Leading dimension for src
+  const int src_ld;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  const constant MLXConvParams<2>& params;
+
+  int weight_h;
+  int weight_w;
+
+  /* Constructor */
+  METAL_FUNC Conv2DWeightBlockLoader(
+      const device T* src_,
+      threadgroup T* dst_,
+      const constant MLXConvParams<2>& params_,
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(params_.wt_strides[0]),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / n_vecs),
+        bj(vec_size * (thread_idx % n_vecs)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld + bj),
+        params(params_),
+        weight_h(0),
+        weight_w(0) {
+    (void)lid;
+    (void)tid;
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    const device T* curr_src =
+        src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
+#pragma clang loop unroll(full)
+    for (short i = 0; i < dst_fd; i += bstride) {
+#pragma clang loop unroll(full)
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    if (++weight_w < params.wS[1]) {
+      return;
+    }
+
+    weight_w = 0;
+
+    if (++weight_h < params.wS[0]) {
+      return;
+    }
+
+    weight_h = 0;
+
+    src += BK;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Transforms
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename OutT, typename InT>
+struct TransformNone {
+  static METAL_FUNC OutT apply(InT x) {
+    return static_cast<OutT>(x);
+  }
+};
+
+template <typename T>
+struct AccumHelper {
+  typedef float accum_type;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// MMA helper
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    int tgp_padding_a = 0,
+    int tgp_padding_b = 0,
+    typename AccumType = typename AccumHelper<T>::accum_type,
+    typename Epilogue = TransformNone<T, AccumType>>
+struct Conv2DBlockMMA {
+  // Warp tile size along M
+  MLX_MTL_CONST int TM = BM / (WM * 8);
+  // Warp tile size along N
+  MLX_MTL_CONST int TN = BN / (WN * 8);
+
+  // Warp tile simdgroup matrix strides along M
+  MLX_MTL_CONST int TM_stride = 8 * WM;
+  // Warp tile simdgroup matrix strides along N
+  MLX_MTL_CONST int TN_stride = 8 * WN;
+
+  // Leading dimensions of threadgroup A, B blocks
+  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
+  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
+
+  // Strides of A, B along reduction axis
+  MLX_MTL_CONST short simd_stride_a =
+      transpose_a ? TM_stride : TM_stride * lda_tgp;
+  MLX_MTL_CONST short simd_stride_b =
+      transpose_b ? TN_stride * ldb_tgp : TN_stride;
+
+  // Jump between elements
+  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
+  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
+
+  // Offsets within threadgroup
+  const int tm;
+  const int tn;
+
+  // Simdgroup matrices
+  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
+  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
+  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
+      simdgroup_matrix<AccumType, 8, 8>(0)};
+
+  short sm;
+  short sn;
+
+  /* Constructor */
+  METAL_FUNC Conv2DBlockMMA(
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
+    short qid = simd_lane_id / 4;
+    sm = (qid & 4) + (simd_lane_id / 2) % 4;
+    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
+  }
+
+  /* (BM, BK) X (BK, BN) multiply accumulate function */
+  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
+    // Iterate over BK in blocks of 8
+#pragma clang loop unroll(full)
+    for (short kk = 0; kk < BK; kk += 8) {
+      short2 offset_a =
+          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
+      short2 offset_b =
+          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
+
+      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
+      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Load elements from threadgroup A as simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short i = 0; i < TM; i++) {
+        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
+        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
+        As__ += simd_stride_a;
+      }
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Load elements from threadgroup B as simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short j = 0; j < TN; j++) {
+        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
+        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
+        Bs__ += simd_stride_b;
+      }
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Multiply and accumulate into result simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short i = 0; i < TM; i++) {
+#pragma clang loop unroll(full)
+        for (short j = 0; j < TN; j++) {
+          simdgroup_multiply_accumulate(
+              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device T* C, const int ldc) const {
+#pragma clang loop unroll(full)
+    for (int i = 0; i < TM; i++) {
+#pragma clang loop unroll(full)
+      for (int j = 0; j < TN; j++) {
+        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
+            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
+        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
+            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
+#pragma clang loop unroll(full)
+    for (int i = 0; i < TM; i++) {
+      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
+#pragma clang loop unroll(full)
+        for (int j = 0; j < TN; j++) {
+          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
+            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
+                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
+          }
+
+          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
+            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
+                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
+          }
+        }
+      }
+    }
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// GEMM kernels
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    typename AccumType = typename AccumHelper<T>::accum_type,
+    typename Epilogue = TransformNone<T, AccumType>>
+struct Conv2DImplicitGEMMKernel {
+  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
+  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
+  MLX_MTL_CONST short tgp_mem_size_a =
+      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
+  MLX_MTL_CONST short tgp_mem_size_b =
+      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
+  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
+
+  MLX_MTL_CONST short tgp_size = WM * WN * 32;
+  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
+
+  using loader_a_t =
+      Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
+  using loader_b_t =
+      Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
+  using mma_t = Conv2DBlockMMA<
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      tgp_padding_a,
+      tgp_padding_b,
+      AccumType,
+      Epilogue>;
+
+  /* Main kernel function */
+  static METAL_FUNC void run(
+      const device T* A [[buffer(0)]],
+      const device T* B [[buffer(1)]],
+      device T* C [[buffer(2)]],
+      const constant MLXConvParams<2>& params [[buffer(3)]],
+      threadgroup T* tgp_memory [[threadgroup(0)]],
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_gid [[simdgroup_index_in_threadgroup]],
+      uint simd_lid [[thread_index_in_simdgroup]]) {
+    const int c_row = tid.y * BM;
+    const int c_col = tid.x * BN;
+    const int K = params.wt_strides[0];
+    const int N = params.O;
+
+    B += c_col * K;
+    C += c_row * N + c_col;
+
+    // Prepare threadgroup memory for loading
+    threadgroup T* As = tgp_memory;
+    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
+
+    // Prepare threadgroup loading operations
+    loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
+    loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
+
+    // Prepare threadgroup mma operation
+    mma_t mma_op(simd_gid, simd_lid);
+
+    for (int k = 0; k < K; k += BK) {
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      // Load elements into threadgroup
+      loader_a.load_unsafe();
+      loader_b.load_unsafe();
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      // Multiply and accumulate threadgroup elements
+      mma_op.mma(As, Bs);
+
+      // Prepare for next iteration
+      loader_a.next();
+      loader_b.next();
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    // Store results to device memory
+    mma_op.store_result(C, N);
+  }
+};
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h ADDED
@@ -0,0 +1,538 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
+
+#define MLX_MTL_CONST static constant constexpr const
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BROWS,
+    int BCOLS,
+    int BK,
+    int vec_size,
+    int tgp_size,
+    bool transpose,
+    bool ldK,
+    int tgp_padding = 0>
+struct BlockLoader {
+  // Destination dimensions
+  MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
+  MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
+  MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;
+
+  // Stride along block row within the block
+  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
+
+  // Leading dimension for src
+  const int src_ld;
+  // Stride along reduction axis between blocks
+  const int tstride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  /* Constructor */
+  METAL_FUNC BlockLoader(
+      const device T* src_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tstride(
+            BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / n_vecs),
+        bj(vec_size * (thread_idx % n_vecs)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld + bj) {}
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+#pragma clang loop unroll(full)
+    for (short i = 0; i < dst_fd; i += bstride) {
+#pragma clang loop unroll(full)
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = src[i * src_ld + j];
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - with bound checking */
+  METAL_FUNC void load_safe(short2 src_tile_dim) const {
+    src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;
+
+    // Iterate over rows of block
+#pragma clang loop unroll(full)
+    for (short i = 0; i < dst_fd; i += bstride) {
+      // Row is in bounds, we check against column
+      if ((bi + i) < src_tile_dim.y) {
+        // Use fast thread memory for bound checks
+        short tmp_idx[vec_size];
+        T tmp_val[vec_size];
+
+        // Make sure tmp_idx only contains valid indices
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; j++) {
+          tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
+        }
+
+        // Read all valid indices into tmp_val
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; j++) {
+          tmp_val[j] = src[i * src_ld + tmp_idx[j]];
+        }
+
+        // Zero out unneeded values
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; j++) {
+          tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
+        }
+
+        // Copy values to threadgroup memory
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = tmp_val[j];
+        }
+      }
+
+      // Row is out of bounds, we just fill tgp memory with zeros
+      else {
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    src += tstride;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Transforms
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename OutT, typename InT>
+struct TransformNone {
+  static METAL_FUNC OutT apply(InT x) {
+    return static_cast<OutT>(x);
+  }
+};
+
+template <typename T>
+struct AccumHelper {
+  typedef float accum_type;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// MMA helper
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    int tgp_padding_a = 0,
+    int tgp_padding_b = 0,
+    typename AccumType = typename AccumHelper<T>::accum_type,
+    typename Epilogue = TransformNone<T, AccumType>>
+struct BlockMMA {
+  // Warp tile size along M
+  MLX_MTL_CONST int TM = BM / (WM * 8);
+  // Warp tile size along N
+  MLX_MTL_CONST int TN = BN / (WN * 8);
+
+  // Warp tile simdgroup matrix strides along M
+  MLX_MTL_CONST int TM_stride = 8 * WM;
+  // Warp tile simdgroup matrix strides along N
+  MLX_MTL_CONST int TN_stride = 8 * WN;
+
+  // Leading dimensions of threadgroup A, B blocks
+  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
+  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
+
+  // Strides of A, B along reduction axis
+  MLX_MTL_CONST short simd_stride_a =
+      transpose_a ? TM_stride : TM_stride * lda_tgp;
+  MLX_MTL_CONST short simd_stride_b =
+      transpose_b ? TN_stride * ldb_tgp : TN_stride;
+
+  // Jump between elements
+  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
+  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
+
+  // Offsets within threadgroup
+  const int tm;
+  const int tn;
+
+  // Simdgroup matrices
+  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
+  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
+  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
+      simdgroup_matrix<AccumType, 8, 8>(0)};
+
+  short sm;
+  short sn;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
+    short qid = simd_lane_id / 4;
+    sm = (qid & 4) + (simd_lane_id / 2) % 4;
+    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
+  }
+
+  /* (BM, BK) X (BK, BN) multiply accumulate function */
+  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
+    // Iterate over BK in blocks of 8
+#pragma clang loop unroll(full)
+    for (short kk = 0; kk < BK; kk += 8) {
+      short2 offset_a =
+          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
+      short2 offset_b =
+          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
+
+      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
+      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Load elements from threadgroup A as simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short i = 0; i < TM; i++) {
+        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
+        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
+        As__ += simd_stride_a;
+      }
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Load elements from threadgroup B as simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short j = 0; j < TN; j++) {
+        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
+        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
+        Bs__ += simd_stride_b;
+      }
+
+      simdgroup_barrier(mem_flags::mem_none);
+      // Multiply and accumulate into result simdgroup matrices
+#pragma clang loop unroll(full)
+      for (short i = 0; i < TM; i++) {
+#pragma clang loop unroll(full)
+        for (short j = 0; j < TN; j++) {
+          simdgroup_multiply_accumulate(
+              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device T* C, const int ldc) const {
+#pragma clang loop unroll(full)
+    for (int i = 0; i < TM; i++) {
+#pragma clang loop unroll(full)
+      for (int j = 0; j < TN; j++) {
+        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
+            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
+        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
+            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
+#pragma clang loop unroll(full)
+    for (int i = 0; i < TM; i++) {
+      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
+#pragma clang loop unroll(full)
+        for (int j = 0; j < TN; j++) {
+          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
+            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
+                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
+          }
+
+          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
+            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
+                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
+          }
+        }
+      }
+    }
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// GEMM kernels
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    bool MN_aligned,
+    bool K_aligned,
+    typename AccumType = typename AccumHelper<T>::accum_type,
+    typename Epilogue = TransformNone<T, AccumType>>
+struct GEMMKernel {
+  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
+  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
+  MLX_MTL_CONST short tgp_mem_size_a =
+      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
+  MLX_MTL_CONST short tgp_mem_size_b =
+      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
+  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
+
+  MLX_MTL_CONST short tgp_size = WM * WN * 32;
+  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
+
+  using loader_a_t = BlockLoader<
+      T,
+      BM,
+      BK,
+      BK,
+      vec_size,
+      tgp_size,
+      transpose_a,
+      true,
+      tgp_padding_a>;
+  using loader_b_t = BlockLoader<
+      T,
+      BK,
+      BN,
+      BK,
+      vec_size,
+      tgp_size,
+      transpose_b,
+      false,
+      tgp_padding_b>;
+  using mma_t = BlockMMA<
+      T,
+      BM,
+      BN,
+      BK,
+      WM,
+      WN,
+      transpose_a,
+      transpose_b,
+      tgp_padding_a,
+      tgp_padding_b,
+      AccumType,
+      Epilogue>;
+
+  /* Main kernel function */
+  static METAL_FUNC void run(
+      const device T* A [[buffer(0)]],
+      const device T* B [[buffer(1)]],
+      device T* C [[buffer(2)]],
+      const constant int& M [[buffer(3)]],
+      const constant int& N [[buffer(4)]],
+      const constant int& K [[buffer(5)]],
+      const constant int& batch_stride_a [[buffer(6)]],
+      const constant int& batch_stride_b [[buffer(7)]],
+      const constant int& batch_stride_c [[buffer(8)]],
+      threadgroup T* tgp_memory [[threadgroup(0)]],
+      uint simd_lane_id [[thread_index_in_simdgroup]],
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]]) {
+    // Pacifying compiler
+    (void)lid;
+
+    // Adjust for batch
+    A += batch_stride_a * tid.z;
+    B += batch_stride_b * tid.z;
+    C += batch_stride_c * tid.z;
+
+    // Adjust for transpose
+    const int lda_dev = transpose_a ? M : K;
+    const int ldb_dev = transpose_b ? K : N;
+
+    // Find block in A, B, C
+    const int c_row = tid.y * BM;
+    const int c_col = tid.x * BN;
+
+    A += transpose_a ? c_row : c_row * K;
+    B += transpose_b ? c_col * K : c_col;
+    C += c_row * N + c_col;
+
+    // Prepare threadgroup memory for loading
+    threadgroup T* As = tgp_memory;
+    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
+
+    // Prepare threadgroup loading operations
+    loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
+    loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);
+
+    // Prepare threadgroup mma operation
+    mma_t mma_op(simd_group_id, simd_lane_id);
+
+    ///////////////////////////////////////////////////////////////////////////
+    // MNK aligned loop
+    if (MN_aligned && K_aligned) {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+
+      threadgroup_barrier(mem_flags::mem_none);
+
+      // Store results to device memory
+      mma_op.store_result(C, N);
+      return;
+
+    }
+    ///////////////////////////////////////////////////////////////////////////
+    // MN aligned, K unaligned loop
+    else if (MN_aligned && !K_aligned) {
+      // Main loop
+      int k = 0;
+      for (; k + BK <= K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        // Load elements into threadgroup
+        loader_a.load_unsafe();
+        loader_b.load_unsafe();
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Multiply and accumulate threadgroup elements
+        mma_op.mma(As, Bs);
+
+        // Prepare for next iteration
+        loader_a.next();
+        loader_b.next();
+      }
+
+      // Loop tail
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      loader_a.load_safe(short2(K - k, BM));
+      loader_b.load_safe(short2(BN, K - k));
+
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      mma_op.mma(As, Bs);
+
+      // Store results to device memory
+      mma_op.store_result(C, N);
+      return;
+
+    }
+    ///////////////////////////////////////////////////////////////////////////
+    // MNK unaligned loop
+    else { // Loop over K - unaligned case
+
+      short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));
+
+      if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
+        int k = 0;
+        for (; k + BK <= K; k += BK) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          // Load elements into threadgroup
+          loader_a.load_unsafe();
+          loader_b.load_unsafe();
+
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          // Multiply and accumulate threadgroup elements
+          mma_op.mma(As, Bs);
+
+          // Prepare for next iteration
+          loader_a.next();
+          loader_b.next();
+        }
+
+        threadgroup_barrier(mem_flags::mem_none);
+
+        if (k < K) {
+          loader_a.load_safe(short2(K - k, BM));
+          loader_b.load_safe(short2(BN, K - k));
+
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          mma_op.mma(As, Bs);
+        }
+
+        mma_op.store_result(C, N);
+        return;
+
+      } else {
+        int k = 0;
+        for (; k + BK <= K; k += BK) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          // Load elements into threadgroup
+          loader_a.load_safe(short2(BK, src_tile_dims.y));
+          loader_b.load_safe(short2(src_tile_dims.x, BK));
+
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          // Multiply and accumulate threadgroup elements
+          mma_op.mma(As, Bs);
+
+          // Prepare for next iteration
+          loader_a.next();
+          loader_b.next();
+        }
+
+        threadgroup_barrier(mem_flags::mem_none);
+
+        if (k < K) {
+          loader_a.load_safe(short2(K - k, src_tile_dims.y));
+          loader_b.load_safe(short2(src_tile_dims.x, K - k));
+
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          mma_op.mma(As, Bs);
+        }
+
+        threadgroup_barrier(mem_flags::mem_none);
+        mma_op.store_result_safe(C, N, src_tile_dims);
+
+        return;
+      }
+    }
+  }
+};
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h ADDED
@@ -0,0 +1,176 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_atomic>
+#include <metal_simdgroup>
+
+#include "mlx/backend/metal/kernels/atomic.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+union bool4_or_uint {
+  bool4 b;
+  unsigned int i;
+};
+
+struct None {
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_store_explicit(out, val, offset);
+  }
+};
+
+struct And {
+  bool simd_reduce(bool val) {
+    return simd_all(val);
+  };
+
+  static constexpr constant bool init = true;
+
+  void atomic_update(
+      device mlx_atomic<unsigned int>* out,
+      bool val,
+      int elem_idx,
+      int offset = 0) {
+    if (!val) {
+      bool4_or_uint update;
+      update.b = {true, true, true, true};
+      update.b[elem_idx] = false;
+      mlx_atomic_fetch_and_explicit(out, update.i, offset);
+    }
+  }
+
+  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
+    if (!val) {
+      mlx_atomic_store_explicit(out, val, offset);
+    }
+  }
+
+  // Non atomic update
+  void update(device bool* out, bool val) {
+    *out &= val;
+  }
+
+  // Operator
+  bool operator()(bool a, bool b) {
+    return a && b;
+  }
+};
+
+struct Or {
+  bool simd_reduce(bool val) {
+    return simd_any(val);
+  };
+
+  static constexpr constant bool init = false;
+
+  void atomic_update(
+      device mlx_atomic<unsigned int>* out,
+      bool val,
+      int elem_idx,
+      int offset = 0) {
+    if (val) {
+      bool4_or_uint update;
+      update.b = {false, false, false, false};
+      update.b[elem_idx] = true;
+      mlx_atomic_fetch_or_explicit(out, update.i, offset);
+    }
+  }
+
+  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
+    if (val) {
+      mlx_atomic_store_explicit(out, val, offset);
+    }
+  }
+
+  // Non atomic update
+  void update(device bool* out, bool val) {
+    *out |= val;
+  }
+
+  // Operator
+  bool operator()(bool a, bool b) {
+    return a || b;
+  }
+};
+
+template <typename U>
+struct Sum {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_sum(val);
+  };
+
+  static constexpr constant U init = U(0);
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_add_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a + b;
+  }
+};
+
+template <typename U>
+struct Prod {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_product(val);
+  };
+
+  static constexpr constant U init = U(1);
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_mul_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a * b;
+  }
+};
+
+template <typename U>
+struct Min {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_min(val);
+  };
+
+  static constexpr constant U init = Limits<U>::max;
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_min_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a < b ? a : b;
+  }
+};
+
+template <typename U>
+struct Max {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_max(val);
+  };
+
+  static constexpr constant U init = Limits<U>::min;
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_max_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a > b ? a : b;
+  }
+};
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h ADDED
@@ -0,0 +1,246 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/complex.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Type limits utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename U>
+struct Limits {
+  static const constant U max;
+  static const constant U min;
+  static const constant U finite_max;
+  static const constant U finite_min;
+};
+
+#define instantiate_default_limit(type)                                      \
+  template <>                                                                \
+  struct Limits<type> {                                                      \
+    static constexpr constant type max = metal::numeric_limits<type>::max(); \
+    static constexpr constant type min = metal::numeric_limits<type>::min(); \
+    static constexpr constant type finite_max =                              \
+        metal::numeric_limits<type>::max();                                  \
+    static constexpr constant type finite_min =                              \
+        metal::numeric_limits<type>::min();                                  \
+  };
+
+instantiate_default_limit(uint8_t);
+instantiate_default_limit(uint16_t);
+instantiate_default_limit(uint32_t);
+instantiate_default_limit(uint64_t);
+instantiate_default_limit(int8_t);
+instantiate_default_limit(int16_t);
+instantiate_default_limit(int32_t);
+instantiate_default_limit(int64_t);
+
+#define instantiate_float_limit(type)             \
+  template <>                                     \
+  struct Limits<type> {                           \
+    static constexpr constant type max =          \
+        metal::numeric_limits<type>::infinity();  \
+    static constexpr constant type min =          \
+        -metal::numeric_limits<type>::infinity(); \
+    static constexpr constant type finite_max =   \
+        metal::numeric_limits<type>::max();       \
+    static constexpr constant type finite_min =   \
+        -metal::numeric_limits<type>::max();      \
+  };
+
+instantiate_float_limit(half);
+instantiate_float_limit(float);
+instantiate_float_limit(bfloat16_t);
+
+template <>
+struct Limits<bool> {
+  static constexpr constant bool max = true;
+  static constexpr constant bool min = false;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Indexing utils
+///////////////////////////////////////////////////////////////////////////////
+
+inline size_t elem_to_loc(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides,
+    int ndim) {
+  size_t loc = 0;
+  for (int i = ndim - 1; i >= 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+inline size_t elem_to_loc(
+    uint elem,
+    constant const int* shape,
+    constant const size_t* strides,
+    int ndim) {
+  size_t loc = 0;
+  for (int i = ndim - 1; i >= 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+template <int NDIM>
+inline uint2 elem_to_loc_2_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM]) {
+  uint2 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+template <int NDIM>
+inline size_t elem_to_loc_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t strides[NDIM]) {
+  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+  for (int d = NDIM - 3; d >= 0; --d) {
+    loc += (elem.z % shape[d]) * strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
+  return elem * stride;
+}
+
+inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
+  return elem.x * strides[1] + elem.y * strides[0];
+}
+
+inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
+  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+}
+
+// Non templated version to handle arbitrary dims
+inline size_t elem_to_loc(
+    uint3 elem,
+    constant const int* shape,
+    constant const size_t* strides,
+    int ndim) {
+  size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
+  for (int d = ndim - 3; d >= 0; --d) {
+    loc += (elem.z % shape[d]) * strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+inline uint2 elem_to_loc_2_nd(
+    uint3 elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    int ndim) {
+  uint2 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  for (int d = ndim - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+template <int NDIM>
+inline uint elem_to_loc_nd(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides);
+
+template <>
+inline uint elem_to_loc_nd<1>(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides) {
+  return (elem % shape[0]) * strides[0];
+}
+
+template <>
+inline uint elem_to_loc_nd<2>(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides) {
+  uint loc = (elem % shape[1]) * strides[1];
+  elem /= shape[1];
+  loc += (elem % shape[0]) * strides[0];
+  return loc;
+}
+
+template <>
+inline uint elem_to_loc_nd<3>(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides) {
+  uint loc = (elem % shape[2]) * strides[2];
+  elem /= shape[2];
+  loc += (elem % shape[1]) * strides[1];
+  elem /= shape[1];
+  loc += (elem % shape[0]) * strides[0];
+  return loc;
+}
+
+template <>
+inline uint elem_to_loc_nd<4>(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides) {
+  uint loc = (elem % shape[3]) * strides[3];
+  elem /= shape[3];
+  loc += (elem % shape[2]) * strides[2];
+  elem /= shape[2];
+  loc += (elem % shape[1]) * strides[1];
+  elem /= shape[1];
+  loc += (elem % shape[0]) * strides[0];
+  return loc;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Calculation utils
+///////////////////////////////////////////////////////////////////////////////
+
+/** Compute ceil((float)N/(float)M) */
+inline size_t ceildiv(size_t N, size_t M) {
+  return (N + M - 1) / M;
+}
+
+// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
+inline float log1p(float x) {
+  float xp1 = 1.0f + x;
+  return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
+}
+
+inline bfloat16_t log1p(bfloat16_t x) {
+  float xp1 = 1.0f + static_cast<float>(x);
+  bfloat16_t ret =
+      (xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
+  return ret;
+}
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h ADDED
@@ -0,0 +1,31 @@
+// Copyright © 2023 Apple Inc.
+
+#include <algorithm>
+#include <cassert>
+#include <sstream>
+
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/mps/gemm.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+void mlx_matmul(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies);
+
+} // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h ADDED
@@ -0,0 +1,31 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <future>
+ #include <memory>
+ #include <vector>
+
+ #include "mlx/array.h"
+ #include "mlx/stream.h"
+
+ namespace mlx::core::metal {
+
+ constexpr bool is_available() {
+ #ifdef _METAL_
+ return true;
+ #else
+ return false;
+ #endif
+ }
+
+ void new_stream(Stream stream);
+ std::shared_ptr<void> new_scoped_memory_pool();
+
+ std::function<void()> make_task(
+ array& arr,
+ std::vector<std::shared_future<void>> deps,
+ std::shared_ptr<std::promise<void>> p,
+ bool retain_graph);
+
+ } // namespace mlx::core::metal
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h ADDED
@@ -0,0 +1,370 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <Metal/Metal.hpp>
+
+ #define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
+ #define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
+
+ namespace MTL::Private::Class {
+ _MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
+ _MTL_PRIVATE_DEF_CLS(MPSMatrix);
+ _MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
+ _MTL_PRIVATE_DEF_CLS(MPSVector);
+ _MTL_PRIVATE_DEF_CLS(MPSKernel);
+ _MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
+ _MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
+ } // namespace MTL::Private::Class
+
+ namespace MTL::Private::Selector {
+ _MTL_PRIVATE_DEF_SEL(
+ matrixDescriptorWithRows_columns_rowBytes_dataType,
+ "matrixDescriptorWithRows:columns:rowBytes:dataType:");
+ _MTL_PRIVATE_DEF_SEL(
+ matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
+ "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
+ _MTL_PRIVATE_DEF_SEL(rows, "rows");
+ _MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
+ _MTL_PRIVATE_DEF_SEL(
+ initWithDevice_,
+ "initWithDevice:transposeLeft:transposeRight:"
+ "resultRows:resultColumns:interiorColumns:alpha:beta:");
+ _MTL_PRIVATE_DEF_SEL(
+ encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
+ "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
+ _MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
+ _MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
+ _MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
+ _MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
+ _MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
+ _MTL_PRIVATE_DEF_SEL(
+ vectorDescriptorWithLength_dataType,
+ "vectorDescriptorWithLength:dataType:");
+ _MTL_PRIVATE_DEF_SEL(
+ vectorDescriptorWithLength_vectors_vectorBytes_dataType,
+ "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
+ _MTL_PRIVATE_DEF_SEL(
+ initWithDevice_transpose_rows_columns_alpha_beta,
+ "initWithDevice:transpose:rows:columns:alpha:beta:");
+ _MTL_PRIVATE_DEF_SEL(
+ encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
+ "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
+ } // namespace MTL::Private::Selector
+
+ namespace MPS {
+
+ typedef enum DataType : uint32_t {
+ DataTypeFloatBit = 0x10000000,
+ DataTypeAlternateEncodingBit = 0x80000000,
+ DataTypeFloat16 = DataTypeFloatBit | 16,
+ DataTypeFloat32 = DataTypeFloatBit | 32,
+ DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
+ } DataType;
+
+ class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
+ public:
+ static class MatrixDescriptor* matrixDescriptor(
+ NS::UInteger rows,
+ NS::UInteger columns,
+ NS::UInteger rowBytes,
+ NS::UInteger dataType);
+ static class MatrixDescriptor* matrixDescriptor(
+ NS::UInteger rows,
+ NS::UInteger columns,
+ NS::UInteger matrices,
+ NS::UInteger rowBytes,
+ NS::UInteger matrixBytes,
+ NS::UInteger dataType);
+ NS::UInteger rows() const;
+ };
+
+ class Matrix : public NS::Referencing<Matrix> {
+ public:
+ static class Matrix* alloc();
+ Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
+ Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
+ };
+
+ class Kernel : public NS::Referencing<Kernel> {
+ public:
+ NS::String* label() const;
+ MTL::Device* device() const;
+ };
+
+ class MatrixMultiplication
+ : public NS::Referencing<MatrixMultiplication, Kernel> {
+ public:
+ static class MatrixMultiplication* alloc();
+
+ MatrixMultiplication* init(
+ MTL::Device* device,
+ bool transposeLeft,
+ bool transposeRight,
+ NS::UInteger resultRows,
+ NS::UInteger resultColumns,
+ NS::UInteger interiorColumns,
+ double alpha,
+ double beta);
+
+ void encodeToCommandBuffer(
+ MTL::CommandBuffer* commandBuffer,
+ Matrix* leftMatrix,
+ Matrix* rightMatrix,
+ Matrix* resultMatrix);
+
+ void setLeftMatrixOrigin(MTL::Origin origin);
+ void setRightMatrixOrigin(MTL::Origin origin);
+ void setResultMatrixOrigin(MTL::Origin origin);
+ void setBatchStart(NS::UInteger batchStart);
+ void setBatchSize(NS::UInteger batchSize);
+ };
+
+ class VectorDescriptor : public NS::Copying<VectorDescriptor> {
+ public:
+ static class VectorDescriptor* vectorDescriptor(
+ NS::UInteger length,
+ NS::UInteger dataType);
+ static class VectorDescriptor* vectorDescriptor(
+ NS::UInteger length,
+ NS::UInteger vectors,
+ NS::UInteger vectorBytes,
+ NS::UInteger dataType);
+ };
+
+ class Vector : public NS::Referencing<Vector> {
+ public:
+ static class Vector* alloc();
+ Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
+ Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
+ };
+
+ class MatrixVectorMultiplication
+ : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
+ public:
+ static class MatrixVectorMultiplication* alloc();
+
+ MatrixVectorMultiplication* init(
+ MTL::Device* device,
+ bool transpose,
+ NS::UInteger rows,
+ NS::UInteger columns,
+ double alpha,
+ double beta);
+
+ void encodeToCommandBuffer(
+ MTL::CommandBuffer* commandBuffer,
+ Matrix* inputMatrix,
+ Vector* inputVector,
+ Vector* resultVector);
+ };
+
+ _MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
+ NS::UInteger rows,
+ NS::UInteger columns,
+ NS::UInteger rowBytes,
+ NS::UInteger dataType) {
+ return Object::sendMessage<MatrixDescriptor*>(
+ _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
+ _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
+ rows,
+ columns,
+ rowBytes,
+ dataType);
+ }
+
+ _MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
+ NS::UInteger rows,
+ NS::UInteger columns,
+ NS::UInteger matrices,
+ NS::UInteger rowBytes,
+ NS::UInteger matrixBytes,
+ NS::UInteger dataType) {
+ return Object::sendMessage<MatrixDescriptor*>(
+ _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
+ _MPS_PRIVATE_SEL(
+ matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
+ rows,
+ columns,
+ matrices,
+ rowBytes,
+ matrixBytes,
+ dataType);
+ }
+
+ _MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
+ return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
+ }
+
+ _MTL_INLINE Matrix* Matrix::alloc() {
+ return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
+ }
+
+ _MTL_INLINE Matrix* Matrix::init(
+ MTL::Buffer* buffer,
+ MatrixDescriptor* descriptor) {
+ return Object::sendMessage<Matrix*>(
+ this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
+ }
+
+ _MTL_INLINE Matrix* Matrix::init(
+ const MTL::Buffer* buffer,
+ MatrixDescriptor* descriptor) {
+ return init(const_cast<MTL::Buffer*>(buffer), descriptor);
+ }
+
+ _MTL_INLINE NS::String* Kernel::label() const {
+ return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
+ }
+
+ _MTL_INLINE MTL::Device* Kernel::device() const {
+ return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
+ }
+
+ _MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
+ return NS::Object::alloc<MatrixMultiplication>(
+ _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
+ }
+
+ _MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
+ MTL::Device* device,
+ bool transposeLeft,
+ bool transposeRight,
+ NS::UInteger resultRows,
+ NS::UInteger resultColumns,
+ NS::UInteger interiorColumns,
+ double alpha,
+ double beta) {
+ return Object::sendMessage<MatrixMultiplication*>(
+ this,
+ _MPS_PRIVATE_SEL(initWithDevice_),
+ device,
+ transposeLeft,
+ transposeRight,
+ resultRows,
+ resultColumns,
+ interiorColumns,
+ alpha,
+ beta);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
+ MTL::CommandBuffer* commandBuffer,
+ Matrix* leftMatrix,
+ Matrix* rightMatrix,
+ Matrix* resultMatrix) {
+ return Object::sendMessage<void>(
+ this,
+ _MPS_PRIVATE_SEL(
+ encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
+ commandBuffer,
+ leftMatrix,
+ rightMatrix,
+ resultMatrix);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
+ Object::sendMessage<void>(
+ this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
+ MTL::Origin origin) {
+ Object::sendMessage<void>(
+ this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
+ MTL::Origin origin) {
+ Object::sendMessage<void>(
+ this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
+ Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
+ }
+
+ _MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
+ Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
+ }
+
+ _MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
+ NS::UInteger length,
+ NS::UInteger dataType) {
+ return Object::sendMessage<VectorDescriptor*>(
+ _MPS_PRIVATE_CLS(MPSVectorDescriptor),
+ _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
+ length,
+ dataType);
+ }
+
+ _MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
+ NS::UInteger length,
+ NS::UInteger vectors,
+ NS::UInteger vectorBytes,
+ NS::UInteger dataType) {
+ return Object::sendMessage<VectorDescriptor*>(
+ _MPS_PRIVATE_CLS(MPSVectorDescriptor),
+ _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
+ length,
+ vectors,
+ vectorBytes,
+ dataType);
+ }
+
+ _MTL_INLINE Vector* Vector::alloc() {
+ return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
+ }
+
+ _MTL_INLINE Vector* Vector::init(
+ MTL::Buffer* buffer,
+ VectorDescriptor* descriptor) {
+ return Object::sendMessage<Vector*>(
+ this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
+ }
+
+ _MTL_INLINE Vector* Vector::init(
+ const MTL::Buffer* buffer,
+ VectorDescriptor* descriptor) {
+ return init(const_cast<MTL::Buffer*>(buffer), descriptor);
+ }
+
+ _MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
+ return NS::Object::alloc<MatrixVectorMultiplication>(
+ _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
+ }
+
+ _MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
+ MTL::Device* device,
+ bool transpose,
+ NS::UInteger rows,
+ NS::UInteger columns,
+ double alpha,
+ double beta) {
+ return Object::sendMessage<MatrixVectorMultiplication*>(
+ this,
+ _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
+ device,
+ transpose,
+ rows,
+ columns,
+ alpha,
+ beta);
+ }
+
+ _MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
+ MTL::CommandBuffer* commandBuffer,
+ Matrix* inputMatrix,
+ Vector* inputVector,
+ Vector* resultVector) {
+ return Object::sendMessage<void>(
+ this,
+ _MPS_PRIVATE_SEL(
+ encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
+ commandBuffer,
+ inputMatrix,
+ inputVector,
+ resultVector);
+ }
+
+ } // namespace MPS
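For orientation, a hedged sketch of how these wrappers compose into one single-precision GEMM encode. Every call below is declared in the header above; the `device`, command buffer, and `MTL::Buffer` arguments, the row-major packing, and the helper name `encode_gemm_sketch` are illustrative assumptions, not MLX's actual call sequence:

// Sketch: encode C = A * B (all float32, row-major) with the MPS wrappers.
// device, cmd_buf, and the three MTL::Buffer pointers are set up elsewhere.
void encode_gemm_sketch(
    MTL::Device* device,
    MTL::CommandBuffer* cmd_buf,
    MTL::Buffer* a_buf,
    MTL::Buffer* b_buf,
    MTL::Buffer* c_buf,
    int M, int N, int K) {
  // Descriptors record shape, leading dimension (rowBytes), and dtype.
  auto* a_desc = MPS::MatrixDescriptor::matrixDescriptor(
      M, K, K * sizeof(float), MPS::DataTypeFloat32);
  auto* b_desc = MPS::MatrixDescriptor::matrixDescriptor(
      K, N, N * sizeof(float), MPS::DataTypeFloat32);
  auto* c_desc = MPS::MatrixDescriptor::matrixDescriptor(
      M, N, N * sizeof(float), MPS::DataTypeFloat32);

  // Wrap the raw Metal buffers as MPS matrices.
  auto* A = MPS::Matrix::alloc()->init(a_buf, a_desc);
  auto* B = MPS::Matrix::alloc()->init(b_buf, b_desc);
  auto* C = MPS::Matrix::alloc()->init(c_buf, c_desc);

  // C = alpha * A @ B + beta * C, with alpha = 1 and beta = 0.
  auto* gemm = MPS::MatrixMultiplication::alloc()->init(
      device, /*transposeLeft=*/false, /*transposeRight=*/false,
      M, N, K, /*alpha=*/1.0, /*beta=*/0.0);
  gemm->encodeToCommandBuffer(cmd_buf, A, B, C);
}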
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h ADDED
@@ -0,0 +1,169 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+ #include "mlx/backend/metal/device.h"
+
+ namespace mlx::core {
+
+ namespace {
+
+ void set_array_buffer(
+ MTL::ComputeCommandEncoder* compute_encoder,
+ MTL::ArgumentEncoder* enc,
+ const array& a,
+ int idx) {
+ auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+ auto offset = a.data<char>() -
+ static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+ enc->setBuffer(a_buf, offset, idx);
+ // MTL::Resource usage through argument buffer needs to be explicitly
+ // flagged to enable hazard tracking
+ compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
+ }
+
+ void set_array_buffer(
+ MTL::ComputeCommandEncoder* enc,
+ const array& a,
+ int idx) {
+ auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+ auto offset = a.data<char>() -
+ static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+ enc->setBuffer(a_buf, offset, idx);
+ }
+
+ std::string type_to_name(const array& a) {
+ std::string tname;
+ switch (a.dtype()) {
+ case bool_:
+ tname = "bool_";
+ break;
+ case uint8:
+ tname = "uint8";
+ break;
+ case uint16:
+ tname = "uint16";
+ break;
+ case uint32:
+ tname = "uint32";
+ break;
+ case uint64:
+ tname = "uint64";
+ break;
+ case int8:
+ tname = "int8";
+ break;
+ case int16:
+ tname = "int16";
+ break;
+ case int32:
+ tname = "int32";
+ break;
+ case int64:
+ tname = "int64";
+ break;
+ case float16:
+ tname = "float16";
+ break;
+ case float32:
+ tname = "float32";
+ break;
+ case bfloat16:
+ tname = "bfloat16";
+ break;
+ case complex64:
+ tname = "complex64";
+ break;
+ }
+ return tname;
+ }
+
+ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
+ int pows[3] = {0, 0, 0};
+ int sum = 0;
+ while (true) {
+ int presum = sum;
+ // Check all the pows
+ if (dim0 >= (1 << (pows[0] + 1))) {
+ pows[0]++;
+ sum++;
+ }
+ if (sum == 10) {
+ break;
+ }
+ if (dim1 >= (1 << (pows[1] + 1))) {
+ pows[1]++;
+ sum++;
+ }
+ if (sum == 10) {
+ break;
+ }
+ if (dim2 >= (1 << (pows[2] + 1))) {
+ pows[2]++;
+ sum++;
+ }
+ if (sum == presum || sum == 10) {
+ break;
+ }
+ }
+ return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
+ }
+
+ // Collapse dims that are contiguous to possibly route to a better kernel
+ // e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
+ // should return {{2, 4}, {{1, 2}}}.
+ //
+ // When multiple arrays are passed they should all have the same shape. The
+ // collapsed axes are also the same so one shape is returned.
+ std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+ collapse_contiguous_dims(const std::vector<array>& xs) {
+ // Make a vector that has axes separated with -1. Collapse all axes between
+ // -1.
+ std::vector<int> to_collapse;
+ if (xs[0].ndim() > 0) {
+ to_collapse.push_back(0);
+ for (int i = 1; i < xs[0].ndim(); i++) {
+ bool contiguous = true;
+ for (auto& x : xs) {
+ if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
+ contiguous = false;
+ }
+ if (!contiguous) {
+ break;
+ }
+ }
+ if (!contiguous) {
+ to_collapse.push_back(-1);
+ }
+ to_collapse.push_back(i);
+ }
+ to_collapse.push_back(-1);
+ }
+
+ std::vector<int> out_shape;
+ std::vector<std::vector<size_t>> out_strides(xs.size());
+ for (int i = 0; i < to_collapse.size(); i++) {
+ int current_shape = xs[0].shape()[to_collapse[i]];
+ while (to_collapse[++i] != -1) {
+ current_shape *= xs[0].shape()[to_collapse[i]];
+ }
+ out_shape.push_back(current_shape);
+ for (int j = 0; j < xs.size(); j++) {
+ out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
+ }
+ }
+
+ return std::make_tuple(out_shape, out_strides);
+ }
+
+ template <typename... Arrays>
+ std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+ collapse_contiguous_dims(Arrays... xs) {
+ return collapse_contiguous_dims(
+ std::vector<array>{std::forward<Arrays>(xs)...});
+ }
+
+ } // namespace
+
+ } // namespace mlx::core
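A small sketch of the example given in the comment above `collapse_contiguous_dims`: after the transpose, `x` has shape `{2, 2, 2}` and strides `{1, 4, 2}`, so axes 1 and 2 are mutually contiguous and merge into a single axis of size 4. The `array` constructor and `transpose` call are exactly the public ops the comment itself references; the surrounding function is illustrative:

// Sketch (inside namespace mlx::core, where these helpers live):
void collapse_example() {
  auto x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1});
  auto [shape, strides] = collapse_contiguous_dims(x);
  // shape   == {2, 4}    -- axis 0 kept, axes 1 and 2 merged
  // strides == {{1, 2}}  -- stride of the last axis in each collapsed group
}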
lib/python3.11/site-packages/mlx/include/mlx/device.h ADDED
@@ -0,0 +1,29 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ namespace mlx::core {
+
+ struct Device {
+ enum class DeviceType {
+ cpu,
+ gpu,
+ };
+
+ static constexpr DeviceType cpu = DeviceType::cpu;
+ static constexpr DeviceType gpu = DeviceType::gpu;
+
+ Device(DeviceType type, int index = 0) : type(type), index(index){};
+
+ DeviceType type;
+ int index;
+ };
+
+ const Device& default_device();
+
+ void set_default_device(const Device& d);
+
+ bool operator==(const Device& lhs, const Device& rhs);
+ bool operator!=(const Device& lhs, const Device& rhs);
+
+ } // namespace mlx::core
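A short sketch combining this header with `metal::is_available()` from `metal.h` above: prefer the GPU when the Metal backend was compiled in, otherwise fall back to the CPU. The function name `pick_default_device` is illustrative:

#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"

using namespace mlx::core;

// Illustrative helper: select a sensible default device at startup.
void pick_default_device() {
  if (metal::is_available()) {
    set_default_device(Device(Device::gpu));  // index defaults to 0
  } else {
    set_default_device(Device(Device::cpu));
  }
}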
lib/python3.11/site-packages/mlx/include/mlx/dtype.h ADDED
@@ -0,0 +1,105 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <complex>
+ #include <cstdint>
+ #include <ostream>
+ #include <string>
+
+ #include "mlx/types/complex.h"
+ #include "mlx/types/half_types.h"
+
+ namespace mlx::core {
+
+ struct Dtype {
+ enum class Val {
+ bool_,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ int8,
+ int16,
+ int32,
+ int64,
+ float16,
+ float32,
+ bfloat16,
+ complex64,
+ };
+
+ enum class Kind {
+ b, /* bool */
+ u, /* unsigned int */
+ i, /* signed int */
+ f, /* float */
+ c, /* complex */
+ V, /* void - used for brain float */
+ };
+
+ Val val;
+ const uint8_t size;
+ constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
+ constexpr operator Val() const {
+ return val;
+ };
+ };
+
+ inline bool is_available(const Dtype& dtype) {
+ return true;
+ }
+
+ static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
+
+ static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
+ static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
+ static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
+ static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
+
+ static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
+ static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
+ static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
+ static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
+
+ static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
+ static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
+ static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
+ static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
+
+ Dtype promote_types(const Dtype& t1, const Dtype& t2);
+
+ inline uint8_t size_of(const Dtype& t) {
+ return t.size;
+ }
+
+ Dtype::Kind kindof(const Dtype& t);
+
+ inline bool is_unsigned(const Dtype& t) {
+ return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
+ }
+
+ inline bool is_floating_point(const Dtype& t) {
+ return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
+ kindof(t) == Dtype::Kind::c;
+ }
+
+ inline bool is_complex(const Dtype& t) {
+ return kindof(t) == Dtype::Kind::c;
+ }
+
+ inline bool is_integral(const Dtype& t) {
+ return !(is_floating_point(t));
+ }
+
+ template <typename T>
+ struct TypeToDtype {
+ operator Dtype();
+ };
+
+ // Array protocol typestring for Dtype
+ std::string dtype_to_array_protocol(const Dtype& t);
+ // Dtype from array protocol type string
+ Dtype dtype_from_array_protocol(const std::string& t);
+
+ } // namespace mlx::core
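A sketch of the helpers above. Note the promotion result is an assumption: `promote_types` is only declared here, and int32-with-float32 yielding float32 (array-protocol string `"<f4"` on little-endian) is the conventional outcome, not something this header guarantees:

#include <iostream>
#include "mlx/dtype.h"

using namespace mlx::core;

int main() {
  std::cout << int(size_of(float32)) << "\n";        // 4 bytes
  std::cout << is_floating_point(bfloat16) << "\n";  // 1: Kind::V counts as floating point
  Dtype t = promote_types(int32, float32);           // assumed to be float32
  std::cout << dtype_to_array_protocol(t) << "\n";   // expected "<f4"
}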
lib/python3.11/site-packages/mlx/include/mlx/fft.h ADDED
@@ -0,0 +1,151 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <variant>
+
+ #include "array.h"
+ #include "device.h"
+ #include "stream.h"
+
+ namespace mlx::core::fft {
+
+ using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
+
+ /** Compute the n-dimensional Fourier Transform. */
+ array fftn(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
+ array fftn(const array& a, StreamOrDevice s = {});
+
+ /** Compute the n-dimensional inverse Fourier Transform. */
+ array ifftn(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array ifftn(
+ const array& a,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array ifftn(const array& a, StreamOrDevice s = {});
+
+ /** Compute the one-dimensional Fourier Transform. */
+ inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+ return fftn(a, {n}, {axis}, s);
+ }
+ inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+ return fftn(a, {axis}, s);
+ }
+
+ /** Compute the one-dimensional inverse Fourier Transform. */
+ inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+ return ifftn(a, {n}, {axis}, s);
+ }
+ inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+ return ifftn(a, {axis}, s);
+ }
+
+ /** Compute the two-dimensional Fourier Transform. */
+ inline array fft2(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {}) {
+ return fftn(a, n, axes, s);
+ }
+ inline array fft2(
+ const array& a,
+ const std::vector<int>& axes = {-2, -1},
+ StreamOrDevice s = {}) {
+ return fftn(a, axes, s);
+ }
+
+ /** Compute the two-dimensional inverse Fourier Transform. */
+ inline array ifft2(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {}) {
+ return ifftn(a, n, axes, s);
+ }
+ inline array ifft2(
+ const array& a,
+ const std::vector<int>& axes = {-2, -1},
+ StreamOrDevice s = {}) {
+ return ifftn(a, axes, s);
+ }
+
+ /** Compute the n-dimensional Fourier Transform on a real input. */
+ array rfftn(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array rfftn(
+ const array& a,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array rfftn(const array& a, StreamOrDevice s = {});
+
+ /** Compute the n-dimensional inverse of `rfftn`. */
+ array irfftn(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array irfftn(
+ const array& a,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+ array irfftn(const array& a, StreamOrDevice s = {});
+
+ /** Compute the one-dimensional Fourier Transform on a real input. */
+ inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+ return rfftn(a, {n}, {axis}, s);
+ }
+ inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+ return rfftn(a, {axis}, s);
+ }
+ /** Compute the one-dimensional inverse of `rfft`. */
+ inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
+ return irfftn(a, {n}, {axis}, s);
+ }
+ inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
+ return irfftn(a, {axis}, s);
+ }
+
+ /** Compute the two-dimensional Fourier Transform on a real input. */
+ inline array rfft2(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {}) {
+ return rfftn(a, n, axes, s);
+ }
+ inline array rfft2(
+ const array& a,
+ const std::vector<int>& axes = {-2, -1},
+ StreamOrDevice s = {}) {
+ return rfftn(a, axes, s);
+ }
+
+ /** Compute the two-dimensional inverse of `rfft2`. */
+ inline array irfft2(
+ const array& a,
+ const std::vector<int>& n,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {}) {
+ return irfftn(a, n, axes, s);
+ }
+ inline array irfft2(
+ const array& a,
+ const std::vector<int>& axes = {-2, -1},
+ StreamOrDevice s = {}) {
+ return irfftn(a, axes, s);
+ }
+
+ } // namespace mlx::core::fft
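A usage sketch of the real-input overloads above: `rfft` along the last axis followed by `irfft` to round-trip. `ones` and `eval` come from the sibling headers in this package, and the output length follows the usual real-FFT convention of n/2 + 1 frequencies:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = ones({8});     // real float32 input
  array y = fft::rfft(x);  // complex64, 8 / 2 + 1 = 5 frequencies
  array z = fft::irfft(y); // back to a real array of length 8
  eval(z);                 // force the lazy graph to run
}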
lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h ADDED
@@ -0,0 +1,23 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+
+ namespace mlx::core {
+
+ void print_graph(std::ostream& os, const std::vector<array>& outputs);
+
+ template <typename... Arrays>
+ void print_graph(std::ostream& os, Arrays... outputs) {
+ print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
+ }
+
+ void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
+
+ template <typename... Arrays>
+ void export_to_dot(std::ostream& os, Arrays... outputs) {
+ export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
+ }
+
+ } // namespace mlx::core
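A sketch of the variadic convenience overloads above, dumping a tiny lazy graph as Graphviz dot. `operator+` on arrays is declared in ops.h; the node labels in the output are whatever `export_to_dot` chooses:

#include <iostream>
#include "mlx/graph_utils.h"
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array a = ones({2, 2});
  array b = ones({2, 2});
  array c = a + b;             // builds a small lazy compute graph
  export_to_dot(std::cout, c); // variadic overload wraps the vector version
}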
lib/python3.11/site-packages/mlx/include/mlx/io/load.h ADDED
@@ -0,0 +1,114 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <fstream>
+ #include <istream>
+ #include <memory>
+
+ namespace mlx::core {
+
+ namespace io {
+
+ class Reader {
+ public:
+ virtual bool is_open() const = 0;
+ virtual bool good() const = 0;
+ virtual size_t tell() const = 0;
+ virtual void seek(
+ int64_t off,
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
+ virtual void read(char* data, size_t n) = 0;
+ virtual std::string label() const = 0;
+ };
+
+ class Writer {
+ public:
+ virtual bool is_open() const = 0;
+ virtual bool good() const = 0;
+ virtual size_t tell() const = 0;
+ virtual void seek(
+ int64_t off,
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
+ virtual void write(const char* data, size_t n) = 0;
+ virtual std::string label() const = 0;
+ };
+
+ class FileReader : public Reader {
+ public:
+ explicit FileReader(const std::shared_ptr<std::ifstream>& is)
+ : is_(is), label_("stream") {}
+ explicit FileReader(const std::string& file_path)
+ : is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
+ label_(file_path) {}
+
+ bool is_open() const override {
+ return is_->is_open();
+ }
+
+ bool good() const override {
+ return is_->good();
+ }
+
+ size_t tell() const override {
+ return is_->tellg();
+ }
+
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
+ override {
+ is_->seekg(off, way);
+ }
+
+ void read(char* data, size_t n) override {
+ is_->read(data, n);
+ }
+
+ std::string label() const override {
+ return "file " + label_;
+ }
+
+ private:
+ std::shared_ptr<std::ifstream> is_;
+ std::string label_;
+ };
+
+ class FileWriter : public Writer {
+ public:
+ explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
+ : os_(is), label_("stream") {}
+ explicit FileWriter(const std::string& file_path)
+ : os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
+ label_(file_path) {}
+
+ bool is_open() const override {
+ return os_->is_open();
+ }
+
+ bool good() const override {
+ return os_->good();
+ }
+
+ size_t tell() const override {
+ return os_->tellp();
+ }
+
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
+ override {
+ os_->seekp(off, way);
+ }
+
+ void write(const char* data, size_t n) override {
+ os_->write(data, n);
+ }
+
+ std::string label() const override {
+ return "file " + label_;
+ }
+
+ private:
+ std::shared_ptr<std::ofstream> os_;
+ std::string label_;
+ };
+
+ } // namespace io
+ } // namespace mlx::core
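A round-trip sketch of `FileWriter` and `FileReader` above. The scratch path `"demo.bin"` is illustrative; both classes hold their stream in a `shared_ptr`, so the writer flushes and closes when it goes out of scope:

#include <iostream>
#include "mlx/io/load.h"

using namespace mlx::core::io;

int main() {
  {
    FileWriter w("demo.bin");  // hypothetical scratch file
    const char payload[] = "mlx";
    w.write(payload, sizeof(payload));
  }  // underlying ofstream flushed on destruction

  FileReader r("demo.bin");
  char buf[4] = {};
  r.read(buf, sizeof(buf));
  std::cout << r.label() << ": " << buf << "\n";  // "file demo.bin: mlx"
}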
lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h ADDED
@@ -0,0 +1,32 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <json.hpp>
+
+ #include "mlx/io/load.h"
+ #include "mlx/ops.h"
+ #include "mlx/primitives.h"
+
+ using json = nlohmann::json;
+
+ namespace mlx::core {
+
+ #define ST_F16 "F16"
+ #define ST_BF16 "BF16"
+ #define ST_F32 "F32"
+
+ #define ST_BOOL "BOOL"
+ #define ST_I8 "I8"
+ #define ST_I16 "I16"
+ #define ST_I32 "I32"
+ #define ST_I64 "I64"
+ #define ST_U8 "U8"
+ #define ST_U16 "U16"
+ #define ST_U32 "U32"
+ #define ST_U64 "U64"
+
+ // Note: Complex numbers aren't in the spec yet so this could change -
+ // https://github.com/huggingface/safetensors/issues/389
+ #define ST_C64 "C64"
+ } // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/linalg.h ADDED
@@ -0,0 +1,63 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <optional>
+
+ #include "mlx/array.h"
+ #include "mlx/device.h"
+ #include "mlx/ops.h"
+ #include "mlx/stream.h"
+
+ namespace mlx::core::linalg {
+
+ /**
+ * Compute vector or matrix norms.
+ *
+ * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
+ * - If axis is not provided but ord is, then x must be either 1D or 2D.
+ * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
+ * for matrices) is computed along the given axes. At most 2 axes can be
+ * specified.
+ * - If both axis and ord are provided, then the corresponding matrix or vector
+ * norm is computed. At most 2 axes can be specified.
+ */
+ array norm(
+ const array& a,
+ const double ord,
+ const std::optional<std::vector<int>>& axis = std::nullopt,
+ bool keepdims = false,
+ StreamOrDevice s = {});
+ inline array norm(
+ const array& a,
+ const double ord,
+ int axis,
+ bool keepdims = false,
+ StreamOrDevice s = {}) {
+ return norm(a, ord, std::vector<int>{axis}, keepdims, s);
+ }
+ array norm(
+ const array& a,
+ const std::string& ord,
+ const std::optional<std::vector<int>>& axis = std::nullopt,
+ bool keepdims = false,
+ StreamOrDevice s = {});
+ inline array norm(
+ const array& a,
+ const std::string& ord,
+ int axis,
+ bool keepdims = false,
+ StreamOrDevice s = {}) {
+ return norm(a, ord, std::vector<int>{axis}, keepdims, s);
+ }
+ array norm(
+ const array& a,
+ const std::optional<std::vector<int>>& axis = std::nullopt,
+ bool keepdims = false,
+ StreamOrDevice s = {});
+ inline array
+ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
+ return norm(a, std::vector<int>{axis}, keepdims, s);
+ }
+
+ } // namespace mlx::core::linalg
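A sketch of the overload set above: the axis-free call computes the 2-norm of the flattened input, while passing `ord` and a single axis dispatches through the inline wrapper to the vector form. `ones` and `eval` come from the sibling headers:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array m = ones({3, 4});
  array fro = linalg::norm(m);          // 2-norm of flatten(m) == sqrt(12)
  array rows = linalg::norm(m, 2.0, 1); // one 2-norm per row, shape {3}
  eval(fro, rows);
}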
lib/python3.11/site-packages/mlx/include/mlx/mlx.h ADDED
@@ -0,0 +1,14 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+ #include "mlx/backend/metal/metal.h"
+ #include "mlx/device.h"
+ #include "mlx/fft.h"
+ #include "mlx/linalg.h"
+ #include "mlx/ops.h"
+ #include "mlx/random.h"
+ #include "mlx/stream.h"
+ #include "mlx/transforms.h"
+ #include "mlx/utils.h"
lib/python3.11/site-packages/mlx/include/mlx/ops.h ADDED
@@ -0,0 +1,1094 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <optional>
+ #include <variant>
+
+ #include "array.h"
+ #include "device.h"
+ #include "io/load.h"
+ #include "stream.h"
+
+ namespace mlx::core {
+
+ using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
+
+ Stream to_stream(StreamOrDevice s);
+
+ /** Creation operations */
+
+ /**
+ * A 1D array of numbers starting at `start` (optional),
+ * stopping at `stop`, stepping by `step` (optional). */
+ array arange(
+ double start,
+ double stop,
+ double step,
+ Dtype dtype,
+ StreamOrDevice s = {});
+ array arange(double start, double stop, double step, StreamOrDevice s = {});
+ array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
+ array arange(double start, double stop, StreamOrDevice s = {});
+ array arange(double stop, Dtype dtype, StreamOrDevice s = {});
+ array arange(double stop, StreamOrDevice s = {});
+
+ array arange(int start, int stop, int step, StreamOrDevice s = {});
+ array arange(int start, int stop, StreamOrDevice s = {});
+ array arange(int stop, StreamOrDevice s = {});
+
+ /** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
+ array linspace(
+ double start,
+ double stop,
+ int num = 50,
+ Dtype dtype = float32,
+ StreamOrDevice s = {});
+
+ /** Convert an array to the given data type. */
+ array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
+
+ /** Create a view of an array with the given shape and strides. */
+ array as_strided(
+ const array& a,
+ std::vector<int> shape,
+ std::vector<size_t> strides,
+ size_t offset,
+ StreamOrDevice s = {});
+
+ /** Copy another array. */
+ array copy(const array& a, StreamOrDevice s = {});
+
+ /** Fill an array of the given shape with the given value(s). */
+ array full(
+ const std::vector<int>& shape,
+ const array& vals,
+ Dtype dtype,
+ StreamOrDevice s = {});
+ array full(
+ const std::vector<int>& shape,
+ const array& vals,
+ StreamOrDevice s = {});
+ template <typename T>
+ array full(
+ const std::vector<int>& shape,
+ T val,
+ Dtype dtype,
+ StreamOrDevice s = {}) {
+ return full(shape, array(val, dtype), to_stream(s));
+ }
+ template <typename T>
+ array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
+ return full(shape, array(val), to_stream(s));
+ }
+
+ /** Fill an array of the given shape with zeros. */
+ array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
+ inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
+ return zeros(shape, float32, s);
+ }
+ array zeros_like(const array& a, StreamOrDevice s = {});
+
+ /** Fill an array of the given shape with ones. */
+ array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
+ inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
+ return ones(shape, float32, s);
+ }
+ array ones_like(const array& a, StreamOrDevice s = {});
+
+ /** Fill an array of the given shape (n,m) with ones on the specified diagonal
+ * k, and zeros everywhere else. */
+ array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
+ inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
+ return eye(n, n, 0, dtype, s);
+ }
+ inline array eye(int n, int m, StreamOrDevice s = {}) {
+ return eye(n, m, 0, float32, s);
+ }
+ inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
+ return eye(n, m, k, float32, s);
+ }
+ inline array eye(int n, StreamOrDevice s = {}) {
+ return eye(n, n, 0, float32, s);
+ }
+
+ /** Create a square matrix of shape (n,n) of zeros, with ones on the main
+ * diagonal. */
+ array identity(int n, Dtype dtype, StreamOrDevice s = {});
+ inline array identity(int n, StreamOrDevice s = {}) {
+ return identity(n, float32, s);
+ }
+
+ array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
+ inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
+ return tri(n, n, 0, type, s);
+ }
+
+ array tril(array x, int k, StreamOrDevice s = {});
+ array triu(array x, int k, StreamOrDevice s = {});
+
+ /** array manipulation */
+
+ /** Reshape an array to the given shape. */
+ array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
+
+ /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
+ array flatten(
+ const array& a,
+ int start_axis,
+ int end_axis = -1,
+ StreamOrDevice s = {});
+
+ /** Flatten the array to 1D. */
+ array flatten(const array& a, StreamOrDevice s = {});
+
+ /** Remove singleton dimensions at the given axes. */
+ array squeeze(
+ const array& a,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+
+ /** Remove singleton dimensions at the given axis. */
+ inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
+ return squeeze(a, std::vector<int>{axis}, s);
+ }
+
+ /** Remove all singleton dimensions. */
+ array squeeze(const array& a, StreamOrDevice s = {});
+
+ /** Add a singleton dimension at the given axes. */
+ array expand_dims(
+ const array& a,
+ const std::vector<int>& axes,
+ StreamOrDevice s = {});
+
+ /** Add a singleton dimension at the given axis. */
+ inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
+ return expand_dims(a, std::vector<int>{axis}, s);
+ }
+
+ /** Slice an array. */
+ array slice(
+ const array& a,
+ std::vector<int> start,
+ std::vector<int> stop,
+ std::vector<int> strides,
+ StreamOrDevice s = {});
+
+ /** Slice an array with a stride of 1 in each dimension. */
+ array slice(
+ const array& a,
+ const std::vector<int>& start,
+ const std::vector<int>& stop,
+ StreamOrDevice s = {});
+
+ /** Split an array into sub-arrays along a given axis. */
+ std::vector<array>
+ split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
+ std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
+ std::vector<array> split(
+ const array& a,
+ const std::vector<int>& indices,
+ int axis,
+ StreamOrDevice s = {});
+ std::vector<array>
+ split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
+
+ /**
+ * Clip (limit) the values in an array.
+ */
+ array clip(
+ const array& a,
+ const std::optional<array>& a_min = std::nullopt,
+ const std::optional<array>& a_max = std::nullopt,
+ StreamOrDevice s = {});
+
+ /** Concatenate arrays along a given axis. */
+ array concatenate(
+ const std::vector<array>& arrays,
+ int axis,
+ StreamOrDevice s = {});
+ array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
+
+ /** Stack arrays along a new axis. */
+ array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
+ array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
+
+ /** Repeat an array along an axis. */
+ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
+ array repeat(const array& arr, int repeats, StreamOrDevice s = {});
+
+ /** Permutes the dimensions according to the given axes. */
+ array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
+ inline array transpose(
+ const array& a,
+ std::initializer_list<int> axes,
+ StreamOrDevice s = {}) {
+ return transpose(a, std::vector<int>(axes), s);
+ }
+
+ /** Swap two axes of an array. */
+ array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
+
+ /** Move an axis of an array. */
+ array moveaxis(
+ const array& a,
+ int source,
+ int destination,
+ StreamOrDevice s = {});
+
+ /** Pad an array with a constant value */
+ array pad(
+ const array& a,
+ const std::vector<int>& axes,
+ const std::vector<int>& low_pad_size,
+ const std::vector<int>& high_pad_size,
+ const array& pad_value = array(0),
+ StreamOrDevice s = {});
+
+ /** Pad an array with a constant value along all axes */
+ array pad(
+ const array& a,
+ const std::vector<std::pair<int, int>>& pad_width,
+ const array& pad_value = array(0),
+ StreamOrDevice s = {});
+ array pad(
+ const array& a,
+ const std::pair<int, int>& pad_width,
+ const array& pad_value = array(0),
+ StreamOrDevice s = {});
+ array pad(
+ const array& a,
+ int pad_width,
+ const array& pad_value = array(0),
+ StreamOrDevice s = {});
+
+ /** Permutes the dimensions in reverse order. */
+ array transpose(const array& a, StreamOrDevice s = {});
+
+ /** Broadcast an array to a given shape. */
+ array broadcast_to(
+ const array& a,
+ const std::vector<int>& shape,
+ StreamOrDevice s = {});
+
+ /** Broadcast a vector of arrays against one another. */
+ std::vector<array> broadcast_arrays(
+ const std::vector<array>& inputs,
+ StreamOrDevice s = {});
+
+ /** Comparison operations */
+
+ /** Returns the bool array with (a == b) element-wise. */
+ array equal(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator==(const array& a, const array& b) {
+ return equal(a, b);
+ }
+ template <typename T>
+ array operator==(T a, const array& b) {
+ return equal(array(a), b);
+ }
+ template <typename T>
+ array operator==(const array& a, T b) {
+ return equal(a, array(b));
+ }
+
+ /** Returns the bool array with (a != b) element-wise. */
+ array not_equal(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator!=(const array& a, const array& b) {
+ return not_equal(a, b);
+ }
+ template <typename T>
+ array operator!=(T a, const array& b) {
+ return not_equal(array(a), b);
+ }
+ template <typename T>
+ array operator!=(const array& a, T b) {
+ return not_equal(a, array(b));
+ }
+
+ /** Returns bool array with (a > b) element-wise. */
+ array greater(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator>(const array& a, const array& b) {
+ return greater(a, b);
+ }
+ template <typename T>
+ array operator>(T a, const array& b) {
+ return greater(array(a), b);
+ }
+ template <typename T>
+ array operator>(const array& a, T b) {
+ return greater(a, array(b));
+ }
+
+ /** Returns bool array with (a >= b) element-wise. */
+ array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator>=(const array& a, const array& b) {
+ return greater_equal(a, b);
+ }
+ template <typename T>
+ array operator>=(T a, const array& b) {
+ return greater_equal(array(a), b);
+ }
+ template <typename T>
+ array operator>=(const array& a, T b) {
+ return greater_equal(a, array(b));
+ }
+
+ /** Returns bool array with (a < b) element-wise. */
+ array less(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator<(const array& a, const array& b) {
+ return less(a, b);
+ }
+ template <typename T>
+ array operator<(T a, const array& b) {
+ return less(array(a), b);
+ }
+ template <typename T>
+ array operator<(const array& a, T b) {
+ return less(a, array(b));
+ }
+
+ /** Returns bool array with (a <= b) element-wise. */
+ array less_equal(const array& a, const array& b, StreamOrDevice s = {});
+ inline array operator<=(const array& a, const array& b) {
+ return less_equal(a, b);
+ }
+ template <typename T>
+ array operator<=(T a, const array& b) {
+ return less_equal(array(a), b);
+ }
+ template <typename T>
+ array operator<=(const array& a, T b) {
+ return less_equal(a, array(b));
+ }
+
+ /** True if two arrays have the same shape and elements. */
+ array array_equal(
+ const array& a,
+ const array& b,
+ bool equal_nan,
+ StreamOrDevice s = {});
+ inline array
+ array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
+ return array_equal(a, b, false, s);
+ }
+
+ /** Select from x or y depending on condition. */
+ array where(
+ const array& condition,
+ const array& x,
+ const array& y,
+ StreamOrDevice s = {});
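A short sketch of the pattern these templates enable: the scalar overloads wrap a plain `3` in an `array`, so a mixed scalar/array comparison feeds straight into `where`. `arange` and `zeros_like` are declared earlier in this header; `eval` comes from transforms.h:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = arange(6);                           // {0, 1, 2, 3, 4, 5}
  array mask = x >= 3;                           // operator>=(const array&, T)
  array clipped = where(mask, x, zeros_like(x)); // keep values >= 3, zero the rest
  eval(clipped);                                 // {0, 0, 0, 3, 4, 5}
}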
383
+
384
+ /** Reduction operations */
385
+
386
+ /** True if all elements in the array are true (or non-zero). **/
387
+ array all(const array& a, bool keepdims, StreamOrDevice s = {});
388
+ inline array all(const array& a, StreamOrDevice s = {}) {
389
+ return all(a, false, to_stream(s));
390
+ }
391
+
392
+ /** True if the two arrays are equal within the specified tolerance. */
393
+ array allclose(
394
+ const array& a,
395
+ const array& b,
396
+ double rtol = 1e-5,
397
+ double atol = 1e-8,
398
+ StreamOrDevice s = {});
399
+
400
+ /**
401
+ * Reduces the input along the given axes. An output value is true
402
+ * if all the corresponding inputs are true.
403
+ **/
404
+ array all(
405
+ const array& a,
406
+ const std::vector<int>& axes,
407
+ bool keepdims = false,
408
+ StreamOrDevice s = {});
409
+
410
+ /**
411
+ * Reduces the input along the given axis. An output value is true
412
+ * if all the corresponding inputs are true.
413
+ **/
414
+ array all(
415
+ const array& a,
416
+ int axis,
417
+ bool keepdims = false,
418
+ StreamOrDevice s = {});
419
+
420
+ /** True if any elements in the array are true (or non-zero). **/
421
+ array any(const array& a, bool keepdims, StreamOrDevice s = {});
422
+ inline array any(const array& a, StreamOrDevice s = {}) {
423
+ return any(a, false, to_stream(s));
424
+ }
425
+
426
+ /**
427
+ * Reduces the input along the given axes. An output value is true
428
+ * if any of the corresponding inputs are true.
429
+ **/
430
+ array any(
431
+ const array& a,
432
+ const std::vector<int>& axes,
433
+ bool keepdims = false,
434
+ StreamOrDevice s = {});
435
+
436
+ /**
437
+ * Reduces the input along the given axis. An output value is true
438
+ * if any of the corresponding inputs are true.
439
+ **/
440
+ array any(
441
+ const array& a,
442
+ int axis,
443
+ bool keepdims = false,
444
+ StreamOrDevice s = {});
445
+
446
+ /** Sums the elements of an array. */
447
+ array sum(const array& a, bool keepdims, StreamOrDevice s = {});
448
+ inline array sum(const array& a, StreamOrDevice s = {}) {
449
+ return sum(a, false, to_stream(s));
450
+ }
451
+
452
+ /** Sums the elements of an array along the given axes. */
453
+ array sum(
454
+ const array& a,
455
+ const std::vector<int>& axes,
456
+ bool keepdims = false,
457
+ StreamOrDevice s = {});
458
+
459
+ /** Sums the elements of an array along the given axis. */
460
+ array sum(
461
+ const array& a,
462
+ int axis,
463
+ bool keepdims = false,
464
+ StreamOrDevice s = {});
465
+
466
+ /** Computes the mean of the elements of an array. */
467
+ array mean(const array& a, bool keepdims, StreamOrDevice s = {});
468
+ inline array mean(const array& a, StreamOrDevice s = {}) {
469
+ return mean(a, false, to_stream(s));
470
+ }
471
+
472
+ /** Computes the mean of the elements of an array along the given axes */
473
+ array mean(
474
+ const array& a,
475
+ const std::vector<int>& axes,
476
+ bool keepdims = false,
477
+ StreamOrDevice s = {});
478
+
479
+ /** Computes the mean of the elements of an array along the given axis */
480
+ array mean(
481
+ const array& a,
482
+ int axis,
483
+ bool keepdims = false,
484
+ StreamOrDevice s = {});
485
+
486
+ /** Computes the mean of the elements of an array. */
487
+ array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
488
+ inline array var(const array& a, StreamOrDevice s = {}) {
489
+ return var(a, false, 0, to_stream(s));
490
+ }
491
+
492
+ /** Computes the var of the elements of an array along the given axes */
493
+ array var(
494
+ const array& a,
495
+ const std::vector<int>& axes,
496
+ bool keepdims = false,
497
+ int ddof = 0,
498
+ StreamOrDevice s = {});
499
+
500
+ /** Computes the var of the elements of an array along the given axis */
501
+ array var(
502
+ const array& a,
503
+ int axis,
504
+ bool keepdims = false,
505
+ int ddof = 0,
506
+ StreamOrDevice s = {});
507
+
508
+ /** The product of all elements of the array. */
509
+ array prod(const array& a, bool keepdims, StreamOrDevice s = {});
510
+ inline array prod(const array& a, StreamOrDevice s = {}) {
511
+ return prod(a, false, to_stream(s));
512
+ }
513
+
514
+ /** The product of the elements of an array along the given axes. */
515
+ array prod(
516
+ const array& a,
517
+ const std::vector<int>& axes,
518
+ bool keepdims = false,
519
+ StreamOrDevice s = {});
520
+
521
+ /** The product of the elements of an array along the given axis. */
522
+ array prod(
523
+ const array& a,
524
+ int axis,
525
+ bool keepdims = false,
526
+ StreamOrDevice s = {});
527
+
528
+ /** The maximum of all elements of the array. */
529
+ array max(const array& a, bool keepdims, StreamOrDevice s = {});
530
+ inline array max(const array& a, StreamOrDevice s = {}) {
531
+ return max(a, false, to_stream(s));
532
+ }
533
+
534
+ /** The maximum of the elements of an array along the given axes. */
535
+ array max(
536
+ const array& a,
537
+ const std::vector<int>& axes,
538
+ bool keepdims = false,
539
+ StreamOrDevice s = {});
540
+
541
+ /** The maximum of the elements of an array along the given axis. */
542
+ array max(
543
+ const array& a,
544
+ int axis,
545
+ bool keepdims = false,
546
+ StreamOrDevice s = {});
547
+
548
+ /** The minimum of all elements of the array. */
549
+ array min(const array& a, bool keepdims, StreamOrDevice s = {});
550
+ inline array min(const array& a, StreamOrDevice s = {}) {
551
+ return min(a, false, to_stream(s));
552
+ }
553
+
554
+ /** The minimum of the elements of an array along the given axes. */
555
+ array min(
556
+ const array& a,
557
+ const std::vector<int>& axes,
558
+ bool keepdims = false,
559
+ StreamOrDevice s = {});
560
+
561
+ /** The minimum of the elements of an array along the given axis. */
562
+ array min(
563
+ const array& a,
564
+ int axis,
565
+ bool keepdims = false,
566
+ StreamOrDevice s = {});
567
+
568
+ /** Returns the index of the minimum value in the array. */
569
+ array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
570
+ inline array argmin(const array& a, StreamOrDevice s = {}) {
571
+ return argmin(a, false, s);
572
+ }
573
+
574
+ /** Returns the indices of the minimum values along a given axis. */
575
+ array argmin(
576
+ const array& a,
577
+ int axis,
578
+ bool keepdims = false,
579
+ StreamOrDevice s = {});
580
+
581
+ /** Returns the index of the maximum value in the array. */
582
+ array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
583
+ inline array argmax(const array& a, StreamOrDevice s = {}) {
584
+ return argmax(a, false, s);
585
+ }
586
+
587
+ /** Returns the indices of the maximum values along a given axis. */
588
+ array argmax(
589
+ const array& a,
590
+ int axis,
591
+ bool keepdims = false,
592
+ StreamOrDevice s = {});
593
+
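A short sketch of the argmax/argmin overloads (editorial; same assumptions as the sketch above):

    array a({3.0f, 1.0f, 2.0f, 0.0f}, {2, 2});
    array i = argmax(a);               // index into the flattened array: 0
    array j = argmin(a, /* axis */ 1); // per-row indices, shape {2}: {1, 1}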
594
+ /** Returns a sorted copy of the flattened array. */
595
+ array sort(const array& a, StreamOrDevice s = {});
596
+
597
+ /** Returns a sorted copy of the array along a given axis. */
598
+ array sort(const array& a, int axis, StreamOrDevice s = {});
599
+
600
+ /** Returns indices that sort the flattened array. */
601
+ array argsort(const array& a, StreamOrDevice s = {});
602
+
603
+ /** Returns indices that sort the array along a given axis. */
604
+ array argsort(const array& a, int axis, StreamOrDevice s = {});
605
+
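argsort composes with take_along_axis (declared further down in this header) to reproduce sort; a sketch (editorial):

    array a({3.0f, 1.0f, 2.0f}, {3});
    array order = argsort(a);                               // {1, 2, 0}
    array sorted = take_along_axis(a, order, /* axis */ 0); // equals sort(a): {1, 2, 3}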
606
+ /**
607
+ * Returns a partitioned copy of the flattened array
608
+ * such that the elements smaller than the kth element come first.
609
+ **/
610
+ array partition(const array& a, int kth, StreamOrDevice s = {});
611
+
612
+ /**
613
+ * Returns a partitioned copy of the array along a given axis
614
+ * such that the elements smaller than the kth element come first.
615
+ **/
616
+ array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
617
+
618
+ /**
619
+ * Returns indices that partition the flattened array
620
+ * such that the elements smaller than the kth element come first.
621
+ **/
622
+ array argpartition(const array& a, int kth, StreamOrDevice s = {});
623
+
624
+ /**
625
+ * Returns indices that partition the array along a given axis
626
+ * such that the elements smaller than the kth element come first.
627
+ **/
628
+ array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
629
+
630
+ /** Returns the top k elements of the flattened array. */
631
+ array topk(const array& a, int k, StreamOrDevice s = {});
632
+
633
+ /** Returns the top k elements of the array along a given axis. */
634
+ array topk(const array& a, int k, int axis, StreamOrDevice s = {});
635
+
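A sketch of the partial-ordering helpers (editorial; the element order inside each partition, and of topk's result, is unspecified):

    array a({5.0f, 1.0f, 4.0f, 2.0f, 3.0f}, {5});
    array p = partition(a, /* kth */ 2); // p[2] == 3; p[0], p[1] <= 3; p[3], p[4] >= 3
    array t = topk(a, /* k */ 2);        // the two largest values: {4, 5} in some order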
636
+ /** The logsumexp of all elements of the array. */
637
+ array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
638
+ inline array logsumexp(const array& a, StreamOrDevice s = {}) {
639
+ return logsumexp(a, false, to_stream(s));
640
+ }
641
+
642
+ /** The logsumexp of the elements of an array along the given axes. */
643
+ array logsumexp(
644
+ const array& a,
645
+ const std::vector<int>& axes,
646
+ bool keepdims = false,
647
+ StreamOrDevice s = {});
648
+
649
+ /** The logsumexp of the elements of an array along the given axis. */
650
+ array logsumexp(
651
+ const array& a,
652
+ int axis,
653
+ bool keepdims = false,
654
+ StreamOrDevice s = {});
655
+
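logsumexp computes log(sum(exp(a))) in a numerically stable way; a sketch of why the fused op matters (editorial):

    array a({1000.0f, 1000.0f}, {2});
    array naive = log(sum(exp(a))); // exp overflows to inf in float32
    array stable = logsumexp(a);    // finite: 1000 + log(2)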
656
+ /** Simple arithmetic operations */
657
+
658
+ /** Absolute value of elements in an array. */
659
+ array abs(const array& a, StreamOrDevice s = {});
660
+
661
+ /** Negate an array. */
662
+ array negative(const array& a, StreamOrDevice s = {});
663
+ array operator-(const array& a);
664
+
665
+ /** The sign of the elements in an array. */
666
+ array sign(const array& a, StreamOrDevice s = {});
667
+
668
+ /** Logical not of an array */
669
+ array logical_not(const array& a, StreamOrDevice s = {});
670
+
671
+ /** The reciprocal (1/x) of the elements in an array. */
672
+ array reciprocal(const array& a, StreamOrDevice s = {});
673
+
674
+ /** Add two arrays. */
675
+ array add(const array& a, const array& b, StreamOrDevice s = {});
676
+ array operator+(const array& a, const array& b);
677
+ template <typename T>
678
+ array operator+(T a, const array& b) {
679
+ return add(array(a), b);
680
+ }
681
+ template <typename T>
682
+ array operator+(const array& a, T b) {
683
+ return add(a, array(b));
684
+ }
685
+
686
+ /** Subtract two arrays. */
687
+ array subtract(const array& a, const array& b, StreamOrDevice s = {});
688
+ array operator-(const array& a, const array& b);
689
+ template <typename T>
690
+ array operator-(T a, const array& b) {
691
+ return subtract(array(a), b);
692
+ }
693
+ template <typename T>
694
+ array operator-(const array& a, T b) {
695
+ return subtract(a, array(b));
696
+ }
697
+
698
+ /** Multiply two arrays. */
699
+ array multiply(const array& a, const array& b, StreamOrDevice s = {});
700
+ array operator*(const array& a, const array& b);
701
+ template <typename T>
702
+ array operator*(T a, const array& b) {
703
+ return multiply(array(a), b);
704
+ }
705
+ template <typename T>
706
+ array operator*(const array& a, T b) {
707
+ return multiply(a, array(b));
708
+ }
709
+
710
+ /** Divide two arrays. */
711
+ array divide(const array& a, const array& b, StreamOrDevice s = {});
712
+ array operator/(const array& a, const array& b);
713
+ array operator/(double a, const array& b);
714
+ array operator/(const array& a, double b);
715
+
716
+ /** Compute integer division. Equivalent to doing floor(a / b). */
717
+ array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
718
+
719
+ /** Compute the element-wise remainder of division */
720
+ array remainder(const array& a, const array& b, StreamOrDevice s = {});
721
+ array operator%(const array& a, const array& b);
722
+ template <typename T>
723
+ array operator%(T a, const array& b) {
724
+ return remainder(array(a), b);
725
+ }
726
+ template <typename T>
727
+ array operator%(const array& a, T b) {
728
+ return remainder(a, array(b));
729
+ }
730
+
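A short sketch of the division helpers (editorial; same setup as the examples above):

    array a({7.0f, 5.0f}, {2});
    array b({2.0f, 2.0f}, {2});
    array q = floor_divide(a, b); // floor(a / b): {3, 2}
    array r = a % b;              // remainder(a, b): {1, 1}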
731
+ /** Element-wise maximum between two arrays. */
732
+ array maximum(const array& a, const array& b, StreamOrDevice s = {});
733
+
734
+ /** Element-wise minimum between two arrays. */
735
+ array minimum(const array& a, const array& b, StreamOrDevice s = {});
736
+
737
+ /** Floor the elements of an array. **/
738
+ array floor(const array& a, StreamOrDevice s = {});
739
+
740
+ /** Ceil the elements of an array. **/
741
+ array ceil(const array& a, StreamOrDevice s = {});
742
+
743
+ /** Square the elements of an array. */
744
+ array square(const array& a, StreamOrDevice s = {});
745
+
746
+ /** Exponential of the elements of an array. */
747
+ array exp(const array& a, StreamOrDevice s = {});
748
+
749
+ /** Sine of the elements of an array */
750
+ array sin(const array& a, StreamOrDevice s = {});
751
+
752
+ /** Cosine of the elements of an array */
753
+ array cos(const array& a, StreamOrDevice s = {});
754
+
755
+ /** Tangent of the elements of an array */
756
+ array tan(const array& a, StreamOrDevice s = {});
757
+
758
+ /** Arc Sine of the elements of an array */
759
+ array arcsin(const array& a, StreamOrDevice s = {});
760
+
761
+ /** Arc Cosine of the elements of an array */
762
+ array arccos(const array& a, StreamOrDevice s = {});
763
+
764
+ /** Arc Tangent of the elements of an array */
765
+ array arctan(const array& a, StreamOrDevice s = {});
766
+
767
+ /** Hyperbolic Sine of the elements of an array */
768
+ array sinh(const array& a, StreamOrDevice s = {});
769
+
770
+ /** Hyperbolic Cosine of the elements of an array */
771
+ array cosh(const array& a, StreamOrDevice s = {});
772
+
773
+ /** Hyperbolic Tangent of the elements of an array */
774
+ array tanh(const array& a, StreamOrDevice s = {});
775
+
776
+ /** Inverse Hyperbolic Sine of the elements of an array */
777
+ array arcsinh(const array& a, StreamOrDevice s = {});
778
+
779
+ /** Inverse Hyperbolic Cosine of the elements of an array */
780
+ array arccosh(const array& a, StreamOrDevice s = {});
781
+
782
+ /** Inverse Hyperbolic Tangent of the elements of an array */
783
+ array arctanh(const array& a, StreamOrDevice s = {});
784
+
785
+ /** Natural logarithm of the elements of an array. */
786
+ array log(const array& a, StreamOrDevice s = {});
787
+
788
+ /** Log base 2 of the elements of an array. */
789
+ array log2(const array& a, StreamOrDevice s = {});
790
+
791
+ /** Log base 10 of the elements of an array. */
792
+ array log10(const array& a, StreamOrDevice s = {});
793
+
794
+ /** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
795
+ array log1p(const array& a, StreamOrDevice s = {});
796
+
797
+ /** Element-wise log-add-exp of two arrays: `log(exp(a) + exp(b))`. */
798
+ array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
799
+
800
+ /** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x))`. */
801
+ array sigmoid(const array& a, StreamOrDevice s = {});
802
+
803
+ /** Computes the error function of the elements of an array. */
804
+ array erf(const array& a, StreamOrDevice s = {});
805
+
806
+ /** Computes the inverse error function of the elements of an array. */
807
+ array erfinv(const array& a, StreamOrDevice s = {});
808
+
809
+ /** Stop the flow of gradients. */
810
+ array stop_gradient(const array& a, StreamOrDevice s = {});
811
+
812
+ /** Round the elements of an array to the given number of decimals. */
813
+ array round(const array& a, int decimals, StreamOrDevice s = {});
814
+ inline array round(const array& a, StreamOrDevice s = {}) {
815
+ return round(a, 0, s);
816
+ }
817
+
818
+ /** Matrix-matrix multiplication. */
819
+ array matmul(const array& a, const array& b, StreamOrDevice s = {});
820
+
821
+ /** Gather array entries given indices and slices */
822
+ array gather(
823
+ const array& a,
824
+ const std::vector<array>& indices,
825
+ const std::vector<int>& axes,
826
+ const std::vector<int>& slice_sizes,
827
+ StreamOrDevice s = {});
828
+ inline array gather(
829
+ const array& a,
830
+ const array& indices,
831
+ int axis,
832
+ const std::vector<int>& slice_sizes,
833
+ StreamOrDevice s = {}) {
834
+ return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
835
+ }
836
+
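For the single-axis gather convenience above, the output shape is the index shape followed by slice_sizes; a hedged sketch (editorial; assumes int32 indices via the initializer-list constructor):

    array a({0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}, {4, 3});
    array idx({2, 0}, {2});
    // Take a {1, 3} row slice at each index along axis 0; result shape {2, 1, 3}.
    array rows = gather(a, idx, /* axis */ 0, /* slice_sizes */ {1, 3});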
837
+ /** Take array slices at the given indices of the specified axis. */
838
+ array take(
839
+ const array& a,
840
+ const array& indices,
841
+ int axis,
842
+ StreamOrDevice s = {});
843
+
844
+ /** Take array entries at the given indices treating the array as flattened. */
845
+ array take(const array& a, const array& indices, StreamOrDevice s = {});
846
+
847
+ /** Take array entries at the given indices along the axis */
848
+ array take_along_axis(
849
+ const array& a,
850
+ const array& indices,
851
+ int axis,
852
+ StreamOrDevice s = {});
853
+
854
+ /** Scatter updates to the given indices */
855
+ array scatter(
856
+ const array& a,
857
+ const std::vector<array>& indices,
858
+ const array& updates,
859
+ const std::vector<int>& axes,
860
+ StreamOrDevice s = {});
861
+ inline array scatter(
862
+ const array& a,
863
+ const array& indices,
864
+ const array& updates,
865
+ int axis,
866
+ StreamOrDevice s = {}) {
867
+ return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
868
+ }
869
+
870
+ /** Scatter and add updates to given indices */
871
+ array scatter_add(
872
+ const array& a,
873
+ const std::vector<array>& indices,
874
+ const array& updates,
875
+ const std::vector<int>& axes,
876
+ StreamOrDevice s = {});
877
+ inline array scatter_add(
878
+ const array& a,
879
+ const array& indices,
880
+ const array& updates,
881
+ int axis,
882
+ StreamOrDevice s = {}) {
883
+ return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
884
+ }
885
+
886
+ /** Scatter and prod updates to given indices */
887
+ array scatter_prod(
888
+ const array& a,
889
+ const std::vector<array>& indices,
890
+ const array& updates,
891
+ const std::vector<int>& axes,
892
+ StreamOrDevice s = {});
893
+ inline array scatter_prod(
894
+ const array& a,
895
+ const array& indices,
896
+ const array& updates,
897
+ int axis,
898
+ StreamOrDevice s = {}) {
899
+ return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
900
+ }
901
+
902
+ /** Scatter and max updates to the given indices */
903
+ array scatter_max(
904
+ const array& a,
905
+ const std::vector<array>& indices,
906
+ const array& updates,
907
+ const std::vector<int>& axes,
908
+ StreamOrDevice s = {});
909
+ inline array scatter_max(
910
+ const array& a,
911
+ const array& indices,
912
+ const array& updates,
913
+ int axis,
914
+ StreamOrDevice s = {}) {
915
+ return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
916
+ }
917
+
+ /** Scatter and min updates to the given indices */
918
+ array scatter_min(
919
+ const array& a,
920
+ const std::vector<array>& indices,
921
+ const array& updates,
922
+ const std::vector<int>& axes,
923
+ StreamOrDevice s = {});
924
+ inline array scatter_min(
925
+ const array& a,
926
+ const array& indices,
927
+ const array& updates,
928
+ int axis,
929
+ StreamOrDevice s = {}) {
930
+ return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
931
+ }
932
+
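A sketch of the scatter family (editorial; the updates shape is assumed to mirror gather's output, i.e. the index shape followed by one slice per index):

    array a = zeros({4, 3});    // zeros() is declared earlier in this header
    array idx({2, 0}, {2});
    array updates({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, {2, 1, 3});
    array out = scatter_add(a, idx, updates, /* axis */ 0); // rows 2 and 0 incremented by 1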
933
+ /** Square root the elements of an array. */
934
+ array sqrt(const array& a, StreamOrDevice s = {});
935
+
936
+ /** Reciprocal square root of the elements of an array. */
937
+ array rsqrt(const array& a, StreamOrDevice s = {});
938
+
939
+ /** Softmax of an array. */
940
+ array softmax(
941
+ const array& a,
942
+ const std::vector<int>& axes,
943
+ StreamOrDevice s = {});
944
+
945
+ /** Softmax of an array. */
946
+ array softmax(const array& a, StreamOrDevice s = {});
947
+
948
+ /** Softmax of an array. */
949
+ inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
950
+ return softmax(a, std::vector<int>{axis}, s);
951
+ }
952
+
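A short usage sketch for softmax (editorial):

    array logits({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    array p = softmax(logits, /* axis */ 1); // each row of p sums to 1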
953
+ /** Raise elements of a to the power of b element-wise */
954
+ array power(const array& a, const array& b, StreamOrDevice s = {});
955
+ inline array operator^(const array& a, const array& b) {
956
+ return power(a, b);
957
+ }
958
+ template <typename T>
959
+ array operator^(T a, const array& b) {
960
+ return power(array(a), b);
961
+ }
962
+ template <typename T>
963
+ array operator^(const array& a, T b) {
964
+ return power(a, array(b));
965
+ }
966
+
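Note that operator^ keeps its built-in C++ precedence, which binds more loosely than the arithmetic operators, so power expressions should be parenthesized; a cautionary sketch (editorial):

    array x({1.0f, 2.0f, 3.0f}, {3});
    array y = (x ^ 2.0f) + 1.0f; // power(x, 2) + 1: {2, 5, 10}
    // Unparenthesized, x ^ 2.0f + 1.0f parses as x ^ (2.0f + 1.0f), i.e. power(x, 3).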
967
+ /** Cumulative sum of an array. */
968
+ array cumsum(
969
+ const array& a,
970
+ int axis,
971
+ bool reverse = false,
972
+ bool inclusive = true,
973
+ StreamOrDevice s = {});
974
+
975
+ /** Cumulative product of an array. */
976
+ array cumprod(
977
+ const array& a,
978
+ int axis,
979
+ bool reverse = false,
980
+ bool inclusive = true,
981
+ StreamOrDevice s = {});
982
+
983
+ /** Cumulative max of an array. */
984
+ array cummax(
985
+ const array& a,
986
+ int axis,
987
+ bool reverse = false,
988
+ bool inclusive = true,
989
+ StreamOrDevice s = {});
990
+
991
+ /** Cumulative min of an array. */
992
+ array cummin(
993
+ const array& a,
994
+ int axis,
995
+ bool reverse = false,
996
+ bool inclusive = true,
997
+ StreamOrDevice s = {});
998
+
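The reverse and inclusive flags select among the four scan variants; a sketch for cumsum (editorial):

    array a({1.0f, 2.0f, 3.0f}, {3});
    cumsum(a, 0);                                             // inclusive: {1, 3, 6}
    cumsum(a, 0, /* reverse */ false, /* inclusive */ false); // exclusive: {0, 1, 3}
    cumsum(a, 0, /* reverse */ true);                         // reversed inclusive: {6, 5, 3}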
999
+ /** Convolution operations */
1000
+
1001
+ /** 1D convolution with a filter */
1002
+ array conv1d(
1003
+ const array& input,
1004
+ const array& weight,
1005
+ int stride = 1,
1006
+ int padding = 0,
1007
+ int dilation = 1,
1008
+ int groups = 1,
1009
+ StreamOrDevice s = {});
1010
+
1011
+ /** 2D convolution with a filter */
1012
+ array conv2d(
1013
+ const array& input,
1014
+ const array& weight,
1015
+ const std::pair<int, int>& stride = {1, 1},
1016
+ const std::pair<int, int>& padding = {0, 0},
1017
+ const std::pair<int, int>& dilation = {1, 1},
1018
+ int groups = 1,
1019
+ StreamOrDevice s = {});
1020
+
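A shape sketch for conv2d (editorial; assumes the channels-last layout MLX uses, with input {N, H, W, C_in} and weight {C_out, kH, kW, C_in}, and zeros() from earlier in this header):

    array input = zeros({1, 8, 8, 3});
    array weight = zeros({16, 3, 3, 3});
    array out = conv2d(input, weight, /* stride */ {1, 1}, /* padding */ {1, 1});
    // A 3x3 kernel with padding 1 preserves spatial dims: out has shape {1, 8, 8, 16}.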
1021
+ /** Serialization operations */
1022
+
1023
+ /** Save array to out stream in .npy format */
1024
+ void save(
1025
+ std::shared_ptr<io::Writer> out_stream,
1026
+ array a,
1027
+ bool retain_graph = true);
1028
+
1029
+ /** Save array to file in .npy format */
1030
+ void save(const std::string& file, array a, bool retain_graph = true);
1031
+
1032
+ /** Load array from reader in .npy format */
1033
+ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
1034
+
1035
+ /** Load array from file in .npy format */
1036
+ array load(const std::string& file, StreamOrDevice s = {});
1037
+
1038
+ /** Quantized matmul multiplies x with a quantized matrix w. */
1039
+ array quantized_matmul(
1040
+ const array& x,
1041
+ const array& w,
1042
+ const array& scales,
1043
+ const array& biases,
1044
+ bool transpose = true,
1045
+ int group_size = 64,
1046
+ int bits = 4,
1047
+ StreamOrDevice s = {});
1048
+
1049
+ /** Quantize a matrix along its last axis */
1050
+ std::tuple<array, array, array> quantize(
1051
+ const array& w,
1052
+ int group_size = 64,
1053
+ int bits = 4,
1054
+ StreamOrDevice s = {});
1055
+
1056
+ /** Dequantize a matrix produced by quantize() */
1057
+ array dequantize(
1058
+ const array& w,
1059
+ const array& scales,
1060
+ const array& biases,
1061
+ int group_size = 64,
1062
+ int bits = 4,
1063
+ StreamOrDevice s = {});
1064
+
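A round-trip sketch for the quantization helpers (editorial; `w` and `x` stand for existing float arrays, group_size must divide the last dimension of w, and with the default transpose = true the quantized matrix is applied transposed):

    auto [wq, scales, biases] = quantize(w, /* group_size */ 64, /* bits */ 4);
    array w_hat = dequantize(wq, scales, biases, 64, 4); // approximates w
    array y = quantized_matmul(x, wq, scales, biases);   // same result without materializing w_hat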
1065
+ /** Compute the tensor dot product: a contraction of a and b over multiple dimensions. */
1066
+ array tensordot(
1067
+ const array& a,
1068
+ const array& b,
1069
+ const int dims = 2,
1070
+ StreamOrDevice s = {});
1071
+
1072
+ array tensordot(
1073
+ const array& a,
1074
+ const array& b,
1075
+ const std::pair<std::vector<int>, std::vector<int>>& dims,
1076
+ StreamOrDevice s = {});
1077
+
1078
+ /** Load array map from .safetensors file format */
1079
+ std::unordered_map<std::string, array> load_safetensors(
1080
+ std::shared_ptr<io::Reader> in_stream,
1081
+ StreamOrDevice s = {});
1082
+ std::unordered_map<std::string, array> load_safetensors(
1083
+ const std::string& file,
1084
+ StreamOrDevice s = {});
1085
+
1086
+ void save_safetensors(
1087
+ std::shared_ptr<io::Writer> out_stream,
1088
+ std::unordered_map<std::string, array>,
1089
+ std::optional<bool> retain_graph = std::nullopt);
1090
+ void save_safetensors(
1091
+ const std::string& file,
1092
+ std::unordered_map<std::string, array>,
1093
+ std::optional<bool> retain_graph = std::nullopt);
1094
+ } // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/primitives.h ADDED
@@ -0,0 +1,1636 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "array.h"
6
+ #include "device.h"
7
+ #include "io/load.h"
8
+ #include "stream.h"
9
+
10
+ #define DEFINE_GRADS() \
11
+ array jvp( \
12
+ const std::vector<array>& primals, \
13
+ const std::vector<array>& tangents, \
14
+ const std::vector<int>& argnums) override; \
15
+ \
16
+ std::vector<array> vjp( \
17
+ const std::vector<array>& primals, \
18
+ const array& cotan, \
19
+ const std::vector<int>& argnums) override;
20
+
21
+ #define DEFINE_PRINT(PRIMITIVE) \
22
+ void print(std::ostream& os) override { \
23
+ os << #PRIMITIVE; \
24
+ }
25
+
26
+ #define DEFINE_DEFAULT_IS_EQUIVALENT() \
27
+ bool is_equivalent(const Primitive& other) const override { \
28
+ return true; \
29
+ }
30
+
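For reference, these macros expand in a subclass to ordinary overrides; e.g. DEFINE_PRINT(Abs) below becomes:

    void print(std::ostream& os) override {
      os << "Abs";
    }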
31
+ namespace mlx::core {
32
+
33
+ // Abstract base class
34
+ class Primitive {
35
+ public:
36
+ explicit Primitive(Stream stream) : stream_(stream) {}
37
+
38
+ /** The device the primitive will run on. */
39
+ const Device& device() {
40
+ return stream().device;
41
+ }
42
+
43
+ /** The stream the primitive will run on. */
44
+ const Stream& stream() {
45
+ return stream_;
46
+ }
47
+
48
+ /**
49
+ * A primitive must know how to evaluate itself on
50
+ * the CPU/GPU for the given inputs and populate the output array.
51
+ *
52
+ * To avoid unnecessary allocations, the evaluation function
53
+ * is responsible for allocating space for the array.
54
+ */
55
+ virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
56
+ virtual void eval_gpu(const std::vector<array>& inputs, array& out) = 0;
57
+
58
+ /**
59
+ * The Jacobian-vector product.
60
+ */
61
+ virtual array jvp(
62
+ const std::vector<array>& primals,
63
+ const std::vector<array>& tangents,
64
+ const std::vector<int>& argnums);
65
+
66
+ /**
67
+ * The vector-Jacobian product.
68
+ */
69
+ virtual std::vector<array> vjp(
70
+ const std::vector<array>& primals,
71
+ const array& cotan,
72
+ const std::vector<int>& argnums);
73
+
74
+ /**
75
+ * The primitive must know how to vectorize itself across
76
+ * the given axes. The output is a pair containing the array
77
+ * representing the vectorized computation and the axis which
78
+ * corresponds to the output vectorized dimension.
79
+ */
80
+ virtual std::pair<array, int> vmap(
81
+ const std::vector<array>& inputs,
82
+ const std::vector<int>& axes);
83
+
84
+ /** Print the primitive. */
85
+ virtual void print(std::ostream& os) = 0;
86
+
87
+ /** Equivalence check defaults to false unless overridden by the primitive */
88
+ virtual bool is_equivalent(const Primitive& other) const {
89
+ return false;
90
+ }
91
+
92
+ virtual ~Primitive() = default;
93
+ Primitive(const Primitive& other) = delete;
94
+ Primitive(Primitive&& other) = delete;
95
+ Primitive& operator=(const Primitive& other) = delete;
96
+ Primitive& operator=(Primitive&& other) = delete;
97
+
98
+ private:
99
+ // Every primitive stores the stream it should run in
100
+ Stream stream_;
101
+ };
102
+
103
+ class Abs : public Primitive {
104
+ public:
105
+ explicit Abs(Stream stream) : Primitive(stream){};
106
+
107
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
108
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
109
+
110
+ std::pair<array, int> vmap(
111
+ const std::vector<array>& inputs,
112
+ const std::vector<int>& axes) override;
113
+
114
+ DEFINE_GRADS()
115
+ DEFINE_PRINT(Abs)
116
+ DEFINE_DEFAULT_IS_EQUIVALENT()
117
+
118
+ private:
119
+ void eval(const std::vector<array>& inputs, array& out);
120
+ };
121
+
122
+ class Add : public Primitive {
123
+ public:
124
+ explicit Add(Stream stream) : Primitive(stream){};
125
+
126
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
127
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
128
+
129
+ std::pair<array, int> vmap(
130
+ const std::vector<array>& inputs,
131
+ const std::vector<int>& axes) override;
132
+
133
+ DEFINE_GRADS()
134
+ DEFINE_PRINT(Add)
135
+ DEFINE_DEFAULT_IS_EQUIVALENT()
136
+
137
+ private:
138
+ void eval(const std::vector<array>& inputs, array& out);
139
+ };
140
+
141
+ class Arange : public Primitive {
142
+ public:
143
+ explicit Arange(Stream stream, double start, double stop, double step)
144
+ : Primitive(stream), start_(start), stop_(stop), step_(step){};
145
+
146
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
147
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
148
+
149
+ DEFINE_PRINT(Arange)
150
+ bool is_equivalent(const Primitive& other) const override;
151
+
152
+ private:
153
+ double start_;
154
+ double stop_;
155
+ double step_;
156
+
157
+ void eval(const std::vector<array>& inputs, array& out);
158
+ };
159
+
160
+ class ArcCos : public Primitive {
161
+ public:
162
+ explicit ArcCos(Stream stream) : Primitive(stream){};
163
+
164
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
165
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
166
+
167
+ std::pair<array, int> vmap(
168
+ const std::vector<array>& inputs,
169
+ const std::vector<int>& axes) override;
170
+
171
+ DEFINE_GRADS()
172
+ DEFINE_PRINT(ArcCos)
173
+ DEFINE_DEFAULT_IS_EQUIVALENT()
174
+
175
+ private:
176
+ void eval(const std::vector<array>& inputs, array& out);
177
+ };
178
+
179
+ class ArcCosh : public Primitive {
180
+ public:
181
+ explicit ArcCosh(Stream stream) : Primitive(stream){};
182
+
183
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
184
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
185
+
186
+ std::pair<array, int> vmap(
187
+ const std::vector<array>& inputs,
188
+ const std::vector<int>& axes) override;
189
+
190
+ DEFINE_GRADS()
191
+ DEFINE_PRINT(ArcCosh)
192
+ DEFINE_DEFAULT_IS_EQUIVALENT()
193
+
194
+ private:
195
+ void eval(const std::vector<array>& inputs, array& out);
196
+ };
197
+
198
+ class ArcSin : public Primitive {
199
+ public:
200
+ explicit ArcSin(Stream stream) : Primitive(stream){};
201
+
202
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
203
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
204
+
205
+ std::pair<array, int> vmap(
206
+ const std::vector<array>& inputs,
207
+ const std::vector<int>& axes) override;
208
+
209
+ DEFINE_GRADS()
210
+ DEFINE_PRINT(ArcSin)
211
+ DEFINE_DEFAULT_IS_EQUIVALENT()
212
+
213
+ private:
214
+ void eval(const std::vector<array>& inputs, array& out);
215
+ };
216
+
217
+ class ArcSinh : public Primitive {
218
+ public:
219
+ explicit ArcSinh(Stream stream) : Primitive(stream){};
220
+
221
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
222
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
223
+
224
+ std::pair<array, int> vmap(
225
+ const std::vector<array>& inputs,
226
+ const std::vector<int>& axes) override;
227
+
228
+ DEFINE_GRADS()
229
+ DEFINE_PRINT(ArcSinh)
230
+ DEFINE_DEFAULT_IS_EQUIVALENT()
231
+
232
+ private:
233
+ void eval(const std::vector<array>& inputs, array& out);
234
+ };
235
+
236
+ class ArcTan : public Primitive {
237
+ public:
238
+ explicit ArcTan(Stream stream) : Primitive(stream){};
239
+
240
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
241
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
242
+
243
+ std::pair<array, int> vmap(
244
+ const std::vector<array>& inputs,
245
+ const std::vector<int>& axes) override;
246
+
247
+ DEFINE_GRADS()
248
+ DEFINE_PRINT(ArcTan)
249
+ DEFINE_DEFAULT_IS_EQUIVALENT()
250
+
251
+ private:
252
+ void eval(const std::vector<array>& inputs, array& out);
253
+ };
254
+
255
+ class ArcTanh : public Primitive {
256
+ public:
257
+ explicit ArcTanh(Stream stream) : Primitive(stream){};
258
+
259
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
260
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
261
+
262
+ std::pair<array, int> vmap(
263
+ const std::vector<array>& inputs,
264
+ const std::vector<int>& axes) override;
265
+
266
+ DEFINE_GRADS()
267
+ DEFINE_PRINT(ArcTanh)
268
+ DEFINE_DEFAULT_IS_EQUIVALENT()
269
+
270
+ private:
271
+ void eval(const std::vector<array>& inputs, array& out);
272
+ };
273
+
274
+ class ArgPartition : public Primitive {
275
+ public:
276
+ explicit ArgPartition(Stream stream, int kth, int axis)
277
+ : Primitive(stream), kth_(kth), axis_(axis){};
278
+
279
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
280
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
281
+
282
+ std::pair<array, int> vmap(
283
+ const std::vector<array>& inputs,
284
+ const std::vector<int>& axes) override;
285
+
286
+ DEFINE_PRINT(ArgPartition)
287
+ bool is_equivalent(const Primitive& other) const override;
288
+
289
+ private:
290
+ int kth_;
291
+ int axis_;
292
+
293
+ void eval(const std::vector<array>& inputs, array& out);
294
+ };
295
+
296
+ class ArgReduce : public Primitive {
297
+ public:
298
+ enum ReduceType {
299
+ ArgMin,
300
+ ArgMax,
301
+ };
302
+
303
+ explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
304
+ : Primitive(stream), reduce_type_(reduce_type), axis_(axis){};
305
+
306
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
307
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
308
+
309
+ DEFINE_PRINT(ArgReduce)
310
+ bool is_equivalent(const Primitive& other) const override;
311
+
312
+ private:
313
+ ReduceType reduce_type_;
314
+ int axis_;
315
+
316
+ void eval(const std::vector<array>& inputs, array& out);
317
+ };
318
+
319
+ class ArgSort : public Primitive {
320
+ public:
321
+ explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){};
322
+
323
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
324
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
325
+
326
+ std::pair<array, int> vmap(
327
+ const std::vector<array>& inputs,
328
+ const std::vector<int>& axes) override;
329
+
330
+ DEFINE_PRINT(ArgSort)
331
+ bool is_equivalent(const Primitive& other) const override;
332
+
333
+ private:
334
+ int axis_;
335
+
336
+ void eval(const std::vector<array>& inputs, array& out);
337
+ };
338
+
339
+ class AsType : public Primitive {
340
+ public:
341
+ explicit AsType(Stream stream, Dtype dtype)
342
+ : Primitive(stream), dtype_(dtype){};
343
+
344
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
345
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
346
+
347
+ std::pair<array, int> vmap(
348
+ const std::vector<array>& inputs,
349
+ const std::vector<int>& axes) override;
350
+
351
+ DEFINE_GRADS()
352
+ DEFINE_PRINT(AsType)
353
+ bool is_equivalent(const Primitive& other) const override;
354
+
355
+ private:
356
+ Dtype dtype_;
357
+
358
+ void eval(const std::vector<array>& inputs, array& out);
359
+ };
360
+
361
+ class AsStrided : public Primitive {
362
+ public:
363
+ explicit AsStrided(
364
+ Stream stream,
365
+ const std::vector<int>& shape,
366
+ const std::vector<size_t>& strides,
367
+ size_t offset)
368
+ : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){};
369
+
370
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
371
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
372
+
373
+ DEFINE_GRADS()
374
+ DEFINE_PRINT(AsStrided)
375
+ bool is_equivalent(const Primitive& other) const override;
376
+
377
+ private:
378
+ std::vector<int> shape_;
379
+ std::vector<size_t> strides_;
380
+ size_t offset_;
381
+
382
+ void eval(const std::vector<array>& inputs, array& out);
383
+ };
384
+
385
+ class Broadcast : public Primitive {
386
+ public:
387
+ explicit Broadcast(Stream stream, const std::vector<int>& shape)
388
+ : Primitive(stream), shape_(shape){};
389
+
390
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
391
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
392
+
393
+ std::pair<array, int> vmap(
394
+ const std::vector<array>& inputs,
395
+ const std::vector<int>& axes) override;
396
+
397
+ DEFINE_GRADS()
398
+ DEFINE_PRINT(Broadcast)
399
+ bool is_equivalent(const Primitive& other) const override;
400
+
401
+ private:
402
+ std::vector<int> shape_;
403
+
404
+ void eval(const std::vector<array>& inputs, array& out);
405
+ };
406
+
407
+ class Ceil : public Primitive {
408
+ public:
409
+ explicit Ceil(Stream stream) : Primitive(stream){};
410
+
411
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
412
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
413
+
414
+ std::pair<array, int> vmap(
415
+ const std::vector<array>& inputs,
416
+ const std::vector<int>& axes) override;
417
+
418
+ DEFINE_GRADS()
419
+ DEFINE_PRINT(Ceil)
420
+ DEFINE_DEFAULT_IS_EQUIVALENT()
421
+
422
+ private:
423
+ void eval(const std::vector<array>& inputs, array& out);
424
+ };
425
+
426
+ class Concatenate : public Primitive {
427
+ public:
428
+ explicit Concatenate(Stream stream, int axis)
429
+ : Primitive(stream), axis_(axis){};
430
+
431
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
432
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
433
+
434
+ std::pair<array, int> vmap(
435
+ const std::vector<array>& inputs,
436
+ const std::vector<int>& axes) override;
437
+
438
+ DEFINE_GRADS()
439
+ DEFINE_PRINT(Concatenate)
440
+ bool is_equivalent(const Primitive& other) const override;
441
+
442
+ private:
443
+ int axis_;
444
+
445
+ void eval(const std::vector<array>& inputs, array& out);
446
+ };
447
+
448
+ class Convolution : public Primitive {
449
+ public:
450
+ explicit Convolution(
451
+ Stream stream,
452
+ const std::vector<int>& padding,
453
+ const std::vector<int>& kernel_strides,
454
+ const std::vector<int>& kernel_dilation,
455
+ const std::vector<int>& input_dilation)
456
+ : Primitive(stream),
457
+ padding_(padding),
458
+ kernel_strides_(kernel_strides),
459
+ kernel_dilation_(kernel_dilation),
460
+ input_dilation_(input_dilation){};
461
+
462
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
463
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
464
+
465
+ std::vector<array> vjp(
466
+ const std::vector<array>& primals,
467
+ const array& cotan,
468
+ const std::vector<int>& argnums) override;
469
+
470
+ DEFINE_PRINT(Convolution)
471
+ bool is_equivalent(const Primitive& other) const override;
472
+
473
+ private:
474
+ std::vector<int> padding_;
475
+ std::vector<int> kernel_strides_;
476
+ std::vector<int> kernel_dilation_;
477
+ std::vector<int> input_dilation_;
478
+
479
+ void eval(const std::vector<array>& inputs, array& out);
480
+ };
481
+
482
+ class Copy : public Primitive {
483
+ public:
484
+ explicit Copy(Stream stream) : Primitive(stream){};
485
+
486
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
487
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
488
+
489
+ std::pair<array, int> vmap(
490
+ const std::vector<array>& inputs,
491
+ const std::vector<int>& axes) override;
492
+
493
+ DEFINE_GRADS()
494
+ DEFINE_PRINT(Copy)
495
+ DEFINE_DEFAULT_IS_EQUIVALENT()
496
+
497
+ private:
498
+ void eval(const std::vector<array>& inputs, array& out);
499
+ };
500
+
501
+ class Cos : public Primitive {
502
+ public:
503
+ explicit Cos(Stream stream) : Primitive(stream){};
504
+
505
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
506
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
507
+
508
+ std::pair<array, int> vmap(
509
+ const std::vector<array>& inputs,
510
+ const std::vector<int>& axes) override;
511
+
512
+ DEFINE_GRADS()
513
+ DEFINE_PRINT(Cos)
514
+ DEFINE_DEFAULT_IS_EQUIVALENT()
515
+
516
+ private:
517
+ void eval(const std::vector<array>& inputs, array& out);
518
+ };
519
+
520
+ class Cosh : public Primitive {
521
+ public:
522
+ explicit Cosh(Stream stream) : Primitive(stream){};
523
+
524
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
525
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
526
+
527
+ std::pair<array, int> vmap(
528
+ const std::vector<array>& inputs,
529
+ const std::vector<int>& axes) override;
530
+
531
+ DEFINE_GRADS()
532
+ DEFINE_PRINT(Cosh)
533
+ DEFINE_DEFAULT_IS_EQUIVALENT()
534
+
535
+ private:
536
+ void eval(const std::vector<array>& inputs, array& out);
537
+ };
538
+
539
+ class Divide : public Primitive {
540
+ public:
541
+ explicit Divide(Stream stream) : Primitive(stream){};
542
+
543
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
544
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
545
+
546
+ std::pair<array, int> vmap(
547
+ const std::vector<array>& inputs,
548
+ const std::vector<int>& axes) override;
549
+
550
+ DEFINE_GRADS()
551
+ DEFINE_PRINT(Divide)
552
+ DEFINE_DEFAULT_IS_EQUIVALENT()
553
+
554
+ private:
555
+ void eval(const std::vector<array>& inputs, array& out);
556
+ };
557
+
558
+ class Remainder : public Primitive {
559
+ public:
560
+ explicit Remainder(Stream stream) : Primitive(stream){};
561
+
562
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
563
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
564
+
565
+ std::pair<array, int> vmap(
566
+ const std::vector<array>& inputs,
567
+ const std::vector<int>& axes) override;
568
+
569
+ DEFINE_GRADS()
570
+ DEFINE_PRINT(Remainder)
571
+ DEFINE_DEFAULT_IS_EQUIVALENT()
572
+
573
+ private:
574
+ void eval(const std::vector<array>& inputs, array& out);
575
+ };
576
+
577
+ class Equal : public Primitive {
578
+ public:
579
+ explicit Equal(Stream stream, bool equal_nan = false)
580
+ : Primitive(stream), equal_nan_(equal_nan){};
581
+
582
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
583
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
584
+
585
+ std::pair<array, int> vmap(
586
+ const std::vector<array>& inputs,
587
+ const std::vector<int>& axes) override;
588
+
589
+ DEFINE_GRADS()
590
+ DEFINE_PRINT(Equal)
591
+ DEFINE_DEFAULT_IS_EQUIVALENT()
592
+
593
+ private:
594
+ void eval(const std::vector<array>& inputs, array& out);
595
+ bool equal_nan_;
596
+ };
597
+
598
+ class Erf : public Primitive {
599
+ public:
600
+ explicit Erf(Stream stream) : Primitive(stream){};
601
+
602
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
603
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
604
+
605
+ std::pair<array, int> vmap(
606
+ const std::vector<array>& inputs,
607
+ const std::vector<int>& axes) override;
608
+
609
+ DEFINE_GRADS()
610
+ DEFINE_PRINT(Erf)
611
+ DEFINE_DEFAULT_IS_EQUIVALENT()
612
+
613
+ private:
614
+ void eval(const std::vector<array>& inputs, array& out);
615
+ };
616
+
617
+ class ErfInv : public Primitive {
618
+ public:
619
+ explicit ErfInv(Stream stream) : Primitive(stream){};
620
+
621
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
622
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
623
+
624
+ std::pair<array, int> vmap(
625
+ const std::vector<array>& inputs,
626
+ const std::vector<int>& axes) override;
627
+
628
+ DEFINE_GRADS()
629
+ DEFINE_PRINT(ErfInv)
630
+ DEFINE_DEFAULT_IS_EQUIVALENT()
631
+
632
+ private:
633
+ void eval(const std::vector<array>& inputs, array& out);
634
+ };
635
+
636
+ class Exp : public Primitive {
637
+ public:
638
+ explicit Exp(Stream stream) : Primitive(stream){};
639
+
640
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
641
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
642
+
643
+ std::pair<array, int> vmap(
644
+ const std::vector<array>& inputs,
645
+ const std::vector<int>& axes) override;
646
+
647
+ DEFINE_GRADS()
648
+ DEFINE_PRINT(Exp)
649
+ DEFINE_DEFAULT_IS_EQUIVALENT()
650
+
651
+ private:
652
+ void eval(const std::vector<array>& inputs, array& out);
653
+ };
654
+
655
+ class FFT : public Primitive {
656
+ public:
657
+ explicit FFT(
658
+ Stream stream,
659
+ const std::vector<size_t>& axes,
660
+ bool inverse,
661
+ bool real)
662
+ : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){};
663
+
664
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
665
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
666
+
667
+ std::pair<array, int> vmap(
668
+ const std::vector<array>& inputs,
669
+ const std::vector<int>& axes) override;
670
+
671
+ DEFINE_GRADS()
672
+ DEFINE_PRINT(FFT)
673
+
674
+ bool is_equivalent(const Primitive& other) const override;
675
+
676
+ private:
677
+ std::vector<size_t> axes_;
678
+ bool inverse_;
679
+ bool real_;
680
+
681
+ void eval(const std::vector<array>& inputs, array& out);
682
+ };
683
+
684
+ class Floor : public Primitive {
685
+ public:
686
+ explicit Floor(Stream stream) : Primitive(stream){};
687
+
688
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
689
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
690
+
691
+ std::pair<array, int> vmap(
692
+ const std::vector<array>& inputs,
693
+ const std::vector<int>& axes) override;
694
+
695
+ DEFINE_GRADS()
696
+ DEFINE_PRINT(Floor)
697
+ DEFINE_DEFAULT_IS_EQUIVALENT()
698
+
699
+ private:
700
+ void eval(const std::vector<array>& inputs, array& out);
701
+ };
702
+
703
+ class Full : public Primitive {
704
+ public:
705
+ explicit Full(Stream stream) : Primitive(stream){};
706
+
707
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
708
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
709
+
710
+ std::pair<array, int> vmap(
711
+ const std::vector<array>& inputs,
712
+ const std::vector<int>& axes) override;
713
+
714
+ DEFINE_GRADS()
715
+ DEFINE_PRINT(Full)
716
+ DEFINE_DEFAULT_IS_EQUIVALENT()
717
+
718
+ private:
719
+ void eval(const std::vector<array>& inputs, array& out);
720
+ };
721
+
722
+ class Gather : public Primitive {
723
+ public:
724
+ explicit Gather(
725
+ Stream stream,
726
+ const std::vector<int>& axes,
727
+ const std::vector<int>& slice_sizes)
728
+ : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
729
+
730
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
731
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
732
+
733
+ std::pair<array, int> vmap(
734
+ const std::vector<array>& inputs,
735
+ const std::vector<int>& axes) override;
736
+
737
+ DEFINE_GRADS()
738
+ DEFINE_PRINT(Gather)
739
+ bool is_equivalent(const Primitive& other) const override;
740
+
741
+ private:
742
+ void eval(const std::vector<array>& inputs, array& out);
743
+ std::vector<int> axes_;
744
+ std::vector<int> slice_sizes_;
745
+ };
746
+
747
+ class Greater : public Primitive {
748
+ public:
749
+ explicit Greater(Stream stream) : Primitive(stream){};
750
+
751
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
752
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
753
+
754
+ std::pair<array, int> vmap(
755
+ const std::vector<array>& inputs,
756
+ const std::vector<int>& axes) override;
757
+
758
+ DEFINE_GRADS()
759
+ DEFINE_PRINT(Greater)
760
+ DEFINE_DEFAULT_IS_EQUIVALENT()
761
+
762
+ private:
763
+ void eval(const std::vector<array>& inputs, array& out);
764
+ };
765
+
766
+ class GreaterEqual : public Primitive {
767
+ public:
768
+ explicit GreaterEqual(Stream stream) : Primitive(stream){};
769
+
770
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
771
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
772
+
773
+ std::pair<array, int> vmap(
774
+ const std::vector<array>& inputs,
775
+ const std::vector<int>& axes) override;
776
+
777
+ DEFINE_GRADS()
778
+ DEFINE_PRINT(GreaterEqual)
779
+ DEFINE_DEFAULT_IS_EQUIVALENT()
780
+
781
+ private:
782
+ void eval(const std::vector<array>& inputs, array& out);
783
+ };
784
+
785
+ class Less : public Primitive {
786
+ public:
787
+ explicit Less(Stream stream) : Primitive(stream){};
788
+
789
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
790
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
791
+
792
+ std::pair<array, int> vmap(
793
+ const std::vector<array>& inputs,
794
+ const std::vector<int>& axes) override;
795
+
796
+ DEFINE_GRADS()
797
+ DEFINE_PRINT(Less)
798
+ DEFINE_DEFAULT_IS_EQUIVALENT()
799
+
800
+ private:
801
+ void eval(const std::vector<array>& inputs, array& out);
802
+ };
803
+
804
+ class LessEqual : public Primitive {
805
+ public:
806
+ explicit LessEqual(Stream stream) : Primitive(stream){};
807
+
808
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
809
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
810
+
811
+ std::pair<array, int> vmap(
812
+ const std::vector<array>& inputs,
813
+ const std::vector<int>& axes) override;
814
+
815
+ DEFINE_GRADS()
816
+ DEFINE_PRINT(LessEqual)
817
+ DEFINE_DEFAULT_IS_EQUIVALENT()
818
+
819
+ private:
820
+ void eval(const std::vector<array>& inputs, array& out);
821
+ };
822
+
823
+ class Load : public Primitive {
824
+ public:
825
+ explicit Load(
826
+ Stream stream,
827
+ std::shared_ptr<io::Reader> reader,
828
+ size_t offset,
829
+ bool swap_endianness = false)
830
+ : Primitive(stream),
831
+ reader_(reader),
832
+ offset_(offset),
833
+ swap_endianness_(swap_endianness){};
834
+
835
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
836
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
837
+
838
+ DEFINE_PRINT(Load)
839
+
840
+ private:
841
+ void eval(const std::vector<array>& inputs, array& out);
842
+ std::shared_ptr<io::Reader> reader_;
843
+ size_t offset_;
844
+ bool swap_endianness_;
845
+ };
846
+
847
+ class Log : public Primitive {
848
+ public:
849
+ enum Base { two, ten, e };
850
+
851
+ explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){};
852
+
853
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
854
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
855
+
856
+ std::pair<array, int> vmap(
857
+ const std::vector<array>& inputs,
858
+ const std::vector<int>& axes) override;
859
+
860
+ DEFINE_GRADS()
861
+ DEFINE_PRINT(Log)
862
+ DEFINE_DEFAULT_IS_EQUIVALENT()
863
+
864
+ private:
865
+ Base base_;
866
+ void eval(const std::vector<array>& inputs, array& out);
867
+ };
868
+
869
+ class Log1p : public Primitive {
870
+ public:
871
+ explicit Log1p(Stream stream) : Primitive(stream){};
872
+
873
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
874
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
875
+
876
+ std::pair<array, int> vmap(
877
+ const std::vector<array>& inputs,
878
+ const std::vector<int>& axes) override;
879
+
880
+ DEFINE_GRADS()
881
+ DEFINE_PRINT(Log1p)
882
+
883
+ private:
884
+ void eval(const std::vector<array>& inputs, array& out);
885
+ };
886
+
887
+ class LogicalNot : public Primitive {
888
+ public:
889
+ explicit LogicalNot(Stream stream) : Primitive(stream){};
890
+
891
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
892
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
893
+
894
+ std::pair<array, int> vmap(
895
+ const std::vector<array>& inputs,
896
+ const std::vector<int>& axes) override;
897
+
898
+ DEFINE_GRADS()
899
+ DEFINE_PRINT(LogicalNot)
900
+ DEFINE_DEFAULT_IS_EQUIVALENT()
901
+
902
+ private:
903
+ void eval(const std::vector<array>& inputs, array& out);
904
+ };
905
+
906
+ class LogAddExp : public Primitive {
907
+ public:
908
+ explicit LogAddExp(Stream stream) : Primitive(stream){};
909
+
910
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
911
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
912
+
913
+ std::pair<array, int> vmap(
914
+ const std::vector<array>& inputs,
915
+ const std::vector<int>& axes) override;
916
+
917
+ DEFINE_GRADS()
918
+ DEFINE_PRINT(LogAddExp)
919
+ DEFINE_DEFAULT_IS_EQUIVALENT()
920
+
921
+ private:
922
+ void eval(const std::vector<array>& inputs, array& out);
923
+ };
924
+
925
+ class Matmul : public Primitive {
926
+ public:
927
+ explicit Matmul(Stream stream) : Primitive(stream){};
928
+
929
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
930
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
931
+
932
+ std::vector<array> vjp(
933
+ const std::vector<array>& primals,
934
+ const array& cotan,
935
+ const std::vector<int>& argnums) override;
936
+
937
+ std::pair<array, int> vmap(
938
+ const std::vector<array>& inputs,
939
+ const std::vector<int>& axes) override;
940
+
941
+ DEFINE_PRINT(Matmul)
942
+ DEFINE_DEFAULT_IS_EQUIVALENT()
943
+ };
944
+
945
+ class Maximum : public Primitive {
946
+ public:
947
+ explicit Maximum(Stream stream) : Primitive(stream){};
948
+
949
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
950
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
951
+
952
+ std::pair<array, int> vmap(
953
+ const std::vector<array>& inputs,
954
+ const std::vector<int>& axes) override;
955
+
956
+ DEFINE_GRADS()
957
+ DEFINE_PRINT(Maximum)
958
+ DEFINE_DEFAULT_IS_EQUIVALENT()
959
+
960
+ private:
961
+ void eval(const std::vector<array>& inputs, array& out);
962
+ };
963
+
964
+ class Minimum : public Primitive {
965
+ public:
966
+ explicit Minimum(Stream stream) : Primitive(stream){};
967
+
968
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
969
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
970
+
971
+ std::pair<array, int> vmap(
972
+ const std::vector<array>& inputs,
973
+ const std::vector<int>& axes) override;
974
+
975
+ DEFINE_GRADS()
976
+ DEFINE_PRINT(Minimum)
977
+ DEFINE_DEFAULT_IS_EQUIVALENT()
978
+
979
+ private:
980
+ void eval(const std::vector<array>& inputs, array& out);
981
+ };
982
+
983
+ class Multiply : public Primitive {
984
+ public:
985
+ explicit Multiply(Stream stream) : Primitive(stream){};
986
+
987
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
988
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
989
+
990
+ std::pair<array, int> vmap(
991
+ const std::vector<array>& inputs,
992
+ const std::vector<int>& axes) override;
993
+
994
+ DEFINE_GRADS()
995
+ DEFINE_PRINT(Multiply)
996
+ DEFINE_DEFAULT_IS_EQUIVALENT()
997
+
998
+ private:
999
+ void eval(const std::vector<array>& inputs, array& out);
1000
+ };
1001
+
1002
+ class Negative : public Primitive {
1003
+ public:
1004
+ explicit Negative(Stream stream) : Primitive(stream){};
1005
+
1006
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1007
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1008
+
1009
+ std::pair<array, int> vmap(
1010
+ const std::vector<array>& inputs,
1011
+ const std::vector<int>& axes) override;
1012
+
1013
+ DEFINE_GRADS()
1014
+ DEFINE_PRINT(Negative)
1015
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1016
+
1017
+ private:
1018
+ void eval(const std::vector<array>& inputs, array& out);
1019
+ };
1020
+
1021
+ class NotEqual : public Primitive {
1022
+ public:
1023
+ explicit NotEqual(Stream stream) : Primitive(stream){};
1024
+
1025
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1026
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1027
+
1028
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(NotEqual)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Pad : public Primitive {
+ public:
+ explicit Pad(
+ Stream stream,
+ const std::vector<int>& axes,
+ const std::vector<int>& low_pad_size,
+ const std::vector<int>& high_pad_size)
+ : Primitive(stream),
+ axes_(axes),
+ low_pad_size_(low_pad_size),
+ high_pad_size_(high_pad_size){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Pad)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ std::vector<int> axes_;
+ std::vector<int> low_pad_size_;
+ std::vector<int> high_pad_size_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Partition : public Primitive {
+ public:
+ explicit Partition(Stream stream, int kth, int axis)
+ : Primitive(stream), kth_(kth), axis_(axis){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Partition)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ int kth_;
+ int axis_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Power : public Primitive {
+ public:
+ explicit Power(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Power)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class QuantizedMatmul : public Primitive {
+ public:
+ explicit QuantizedMatmul(
+ Stream stream,
+ int group_size,
+ int bits,
+ bool transpose)
+ : Primitive(stream),
+ group_size_(group_size),
+ bits_(bits),
+ transpose_(transpose){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(QuantizedMatmul)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ int group_size_;
+ int bits_;
+ bool transpose_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class RandomBits : public Primitive {
+ public:
+ explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
+ : Primitive(stream), shape_(shape), width_(width){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_PRINT(RandomBits)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ std::vector<int> shape_;
+ int width_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Reshape : public Primitive {
+ public:
+ explicit Reshape(Stream stream, const std::vector<int>& shape)
+ : Primitive(stream), shape_(shape){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Reshape)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ std::vector<int> shape_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Reduce : public Primitive {
+ public:
+ enum ReduceType { And, Or, Sum, Prod, Min, Max };
+
+ explicit Reduce(
+ Stream stream,
+ ReduceType reduce_type,
+ const std::vector<int>& axes)
+ : Primitive(stream), reduce_type_(reduce_type), axes_(axes){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+ std::vector<array> vjp(
+ const std::vector<array>& primals,
+ const array& cotan,
+ const std::vector<int>& argnums) override;
+
+ void print(std::ostream& os) override {
+ switch (reduce_type_) {
+ case And:
+ os << "And";
+ break;
+ case Or:
+ os << "Or";
+ break;
+ case Sum:
+ os << "Sum";
+ break;
+ case Prod:
+ os << "Prod";
+ break;
+ case Min:
+ os << "Min";
+ break;
+ case Max:
+ os << "Max";
+ break;
+ }
+ os << " Reduce";
+ }
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ ReduceType reduce_type_;
+ std::vector<int> axes_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Round : public Primitive {
+ public:
+ explicit Round(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Round)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Scan : public Primitive {
+ public:
+ enum ReduceType { Max, Min, Sum, Prod };
+
+ explicit Scan(
+ Stream stream,
+ ReduceType reduce_type,
+ int axis,
+ bool reverse,
+ bool inclusive)
+ : Primitive(stream),
+ reduce_type_(reduce_type),
+ axis_(axis),
+ reverse_(reverse),
+ inclusive_(inclusive){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS();
+ void print(std::ostream& os) override {
+ os << "Cum";
+ switch (reduce_type_) {
+ case Sum:
+ os << "Sum";
+ break;
+ case Prod:
+ os << "Prod";
+ break;
+ case Min:
+ os << "Min";
+ break;
+ case Max:
+ os << "Max";
+ break;
+ }
+ os << " Reduce";
+ }
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ ReduceType reduce_type_;
+ int axis_;
+ bool reverse_;
+ bool inclusive_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Scatter : public Primitive {
+ public:
+ enum ReduceType { Max, Min, Sum, Prod, None };
+
+ explicit Scatter(
+ Stream stream,
+ ReduceType reduce_type,
+ const std::vector<int>& axes)
+ : Primitive(stream), reduce_type_(reduce_type), axes_(axes){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ DEFINE_PRINT(Scatter)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ ReduceType reduce_type_;
+ std::vector<int> axes_;
+ };
+
+ class Sigmoid : public Primitive {
+ public:
+ explicit Sigmoid(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sigmoid)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Sign : public Primitive {
+ public:
+ explicit Sign(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sign)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Sin : public Primitive {
+ public:
+ explicit Sin(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sin)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Sinh : public Primitive {
+ public:
+ explicit Sinh(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sinh)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Slice : public Primitive {
+ public:
+ explicit Slice(
+ Stream stream,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& end_indices,
+ const std::vector<int>& strides)
+ : Primitive(stream),
+ start_indices_(start_indices),
+ end_indices_(end_indices),
+ strides_(strides){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Slice)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ std::vector<int> start_indices_;
+ std::vector<int> end_indices_;
+ std::vector<int> strides_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Softmax : public Primitive {
+ public:
+ explicit Softmax(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Softmax)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Sort : public Primitive {
+ public:
+ explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sort)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ int axis_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Square : public Primitive {
+ public:
+ explicit Square(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Square)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Sqrt : public Primitive {
+ public:
+ explicit Sqrt(Stream stream, bool recip = false)
+ : Primitive(stream), recip_(recip){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Sqrt)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ bool recip_;
+ };
+
+ class StopGradient : public Primitive {
+ public:
+ explicit StopGradient(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_PRINT(StopGradient)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Subtract : public Primitive {
+ public:
+ explicit Subtract(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Subtract)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Tan : public Primitive {
+ public:
+ explicit Tan(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Tan)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Tanh : public Primitive {
+ public:
+ explicit Tanh(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Tanh)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Uniform : public Primitive {
+ public:
+ explicit Uniform(Stream stream) : Primitive(stream){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_PRINT(Uniform)
+ DEFINE_DEFAULT_IS_EQUIVALENT()
+
+ private:
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ class Transpose : public Primitive {
+ public:
+ explicit Transpose(Stream stream, const std::vector<int>& axes)
+ : Primitive(stream), axes_(axes){};
+
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+ std::pair<array, int> vmap(
+ const std::vector<array>& inputs,
+ const std::vector<int>& axes) override;
+
+ DEFINE_GRADS()
+ DEFINE_PRINT(Transpose)
+ bool is_equivalent(const Primitive& other) const override;
+
+ private:
+ std::vector<int> axes_;
+
+ void eval(const std::vector<array>& inputs, array& out);
+ };
+
+ } // namespace mlx::core
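Of the unary primitives above, `Sqrt` is the one that carries extra state: its `recip` flag lets a single primitive back both a square root and a reciprocal square root, which is also why it overrides `is_equivalent` rather than using the default. A minimal sketch of how this could surface at the op level, assuming the public `sqrt`/`rsqrt` ops from `mlx/ops.h` lower to this primitive and that the bundled headers and `libmlx.dylib` are on the include and link paths:

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto x = array(4.0f);
      // Both calls are assumed to build graphs over the same Sqrt primitive,
      // differing only in its recip_ flag.
      std::cout << sqrt(x) << std::endl;  // 2
      std::cout << rsqrt(x) << std::endl; // 0.5
    }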
lib/python3.11/site-packages/mlx/include/mlx/random.h ADDED
@@ -0,0 +1,193 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <optional>
+
+ #include "mlx/array.h"
+ #include "mlx/stream.h"
+
+ namespace mlx::core::random {
+
+ class KeySequence {
+ public:
+ explicit KeySequence(uint64_t seed);
+
+ void seed(uint64_t seed);
+ array next();
+
+ // static default
+ static KeySequence& default_() {
+ static KeySequence ks(0);
+ return ks;
+ }
+
+ private:
+ array key_;
+ };
+
+ /** Get a PRNG key from a seed. */
+ array key(uint64_t seed);
+
+ /** Seed the default PRNG key. */
+ void seed(uint64_t seed);
+
+ /** Generate an array with type uint32 filled with random bits. */
+ array bits(
+ const std::vector<int>& shape,
+ int width,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+ inline array bits(
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return bits(shape, 4, key, s);
+ }
+
+ /** Split the rng key into a pair of keys. */
+ std::pair<array, array> split(const array& key, StreamOrDevice s = {});
+
+ /** Split the rng key into `num` keys. */
+ array split(const array& key, int num, StreamOrDevice s = {});
+
+ /** Generate uniform random numbers between low and high. */
+ array uniform(
+ const array& low,
+ const array& high,
+ const std::vector<int>& shape,
+ Dtype dtype = float32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ template <typename T, typename U>
+ array uniform(
+ T low,
+ U high,
+ const std::vector<int>& shape,
+ Dtype dtype = float32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
+ }
+
+ /** Generate uniform random numbers between 0 and 1. */
+ array uniform(
+ const std::vector<int>& shape,
+ Dtype dtype,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+ inline array uniform(
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return uniform(shape, float32, key, s);
+ }
+
+ /** Generate samples from the standard normal distribution. */
+ array normal(
+ const std::vector<int>& shape,
+ Dtype dtype,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+ inline array normal(
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return normal(shape, float32, key, s);
+ }
+
+ /** Generate integer samples uniformly at random */
+ array randint(
+ const array& low,
+ const array& high,
+ const std::vector<int>& shape,
+ Dtype dtype = int32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ template <typename T, typename U>
+ array randint(
+ T low,
+ U high,
+ const std::vector<int>& shape,
+ Dtype dtype = int32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return randint(array(low), array(high), shape, dtype, key, to_stream(s));
+ };
+
+ /** Generate binary variables with probability to be true equal to p */
+ array bernoulli(
+ const array& p,
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+ array bernoulli(
+ const array& p,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ template <typename T>
+ array bernoulli(
+ T p,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return bernoulli(array(p), key, s);
+ };
+
+ template <typename T>
+ array bernoulli(
+ T p,
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {}) {
+ return bernoulli(array(p), shape, key, s);
+ };
+
+ array bernoulli(
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array truncated_normal(
+ const array& lower,
+ const array& upper,
+ const std::vector<int>& shape,
+ Dtype dtype = float32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array truncated_normal(
+ const array& lower,
+ const array& upper,
+ Dtype dtype = float32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array gumbel(
+ const std::vector<int>& shape,
+ Dtype dtype = float32,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array categorical(
+ const array& logits,
+ int axis,
+ const std::vector<int>& shape,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array categorical(
+ const array& logits_,
+ int axis,
+ int num_samples,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ array categorical(
+ const array& logits,
+ int axis = -1,
+ const std::optional<array>& key = std::nullopt,
+ StreamOrDevice s = {});
+
+ } // namespace mlx::core::random
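This header follows the splittable-key PRNG design (the JAX approach): every sampling function is a pure function of an explicit key, `split` derives independent subkeys, and `KeySequence::default_()` supplies implicit state for callers that pass no key. A minimal usage sketch, assuming the bundled headers and `libmlx.dylib` are on the include and link paths:

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto k = random::key(42);         // same seed -> same key -> same draws
      auto [k1, k2] = random::split(k); // one subkey per independent consumer
      auto u = random::uniform({2, 2}, float32, k1); // U[0, 1)
      auto n = random::normal({2, 2}, float32, k2);  // standard normal
      std::cout << u << "\n" << n << std::endl;
    }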
lib/python3.11/site-packages/mlx/include/mlx/scheduler.h ADDED
@@ -0,0 +1,173 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <atomic>
+ #include <future>
+ #include <queue>
+ #include <thread>
+ #include <unordered_map>
+
+ #include "mlx/backend/metal/metal.h"
+ #include "mlx/device.h"
+ #include "mlx/stream.h"
+
+ namespace mlx::core::scheduler {
+
+ struct StreamThread {
+ std::mutex mtx;
+ std::queue<std::function<void()>> q;
+ std::condition_variable cond;
+ bool stop;
+ Stream stream;
+ std::thread thread;
+
+ StreamThread(Stream stream)
+ : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
+
+ ~StreamThread() {
+ {
+ std::unique_lock<std::mutex> lk(mtx);
+ stop = true;
+ }
+ cond.notify_one();
+ thread.join();
+ }
+
+ void thread_fn() {
+ auto thread_pool = metal::new_scoped_memory_pool();
+ metal::new_stream(stream);
+ while (true) {
+ std::function<void()> task;
+ {
+ std::unique_lock<std::mutex> lk(mtx);
+ cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
+ if (q.empty() && stop) {
+ return;
+ }
+ task = std::move(q.front());
+ q.pop();
+ }
+ task();
+ }
+ }
+
+ template <typename F>
+ void enqueue(F&& f) {
+ {
+ std::unique_lock<std::mutex> lk(mtx);
+ if (stop) {
+ throw std::runtime_error(
+ "Cannot enqueue work after stream is stopped.");
+ }
+ q.emplace(std::forward<F>(f));
+ }
+ cond.notify_one();
+ }
+ };
+
+ class Scheduler {
+ public:
+ Scheduler() : n_active_tasks_(0) {
+ if (metal::is_available()) {
+ default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
+ }
+ default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
+ }
+
+ // Not copyable or moveable
+ Scheduler(const Scheduler&) = delete;
+ Scheduler(Scheduler&&) = delete;
+ Scheduler& operator=(const Scheduler&) = delete;
+ Scheduler& operator=(Scheduler&&) = delete;
+
+ Stream new_stream(const Device& d) {
+ auto stream = Stream(streams_.size(), d);
+ streams_.push_back(new StreamThread{stream});
+ return stream;
+ }
+
+ template <typename F>
+ void enqueue(const Stream& stream, F&& f);
+
+ Stream get_default_stream(const Device& d) {
+ return default_streams_.at(d.type);
+ }
+
+ void set_default_stream(const Stream& s) {
+ default_streams_.at(s.device.type) = s;
+ }
+
+ void notify_new_task(const Stream& stream) {
+ {
+ std::unique_lock<std::mutex> lk(mtx);
+ n_active_tasks_++;
+ }
+ completion_cv.notify_all();
+ }
+
+ void notify_task_completion(const Stream& stream) {
+ {
+ std::unique_lock<std::mutex> lk(mtx);
+ n_active_tasks_--;
+ }
+ completion_cv.notify_all();
+ }
+
+ int n_active_tasks() const {
+ return n_active_tasks_;
+ }
+
+ void wait_for_one() {
+ std::unique_lock<std::mutex> lk(mtx);
+ int n_tasks_old = n_active_tasks();
+ if (n_tasks_old > 1) {
+ completion_cv.wait(lk, [this, n_tasks_old] {
+ return this->n_active_tasks() != n_tasks_old;
+ });
+ }
+ }
+
+ ~Scheduler() {
+ for (auto s : streams_) {
+ delete s;
+ }
+ }
+
+ private:
+ int n_active_tasks_;
+ std::vector<StreamThread*> streams_;
+ std::unordered_map<Device::DeviceType, Stream> default_streams_;
+ std::condition_variable completion_cv;
+ std::mutex mtx;
+ };
+
+ template <typename F>
+ void Scheduler::enqueue(const Stream& stream, F&& f) {
+ streams_[stream.index]->enqueue(std::forward<F>(f));
+ }
+
+ Scheduler& scheduler();
+
+ template <typename F>
+ void enqueue(const Stream& stream, F&& f) {
+ scheduler().enqueue(stream, std::forward<F>(f));
+ }
+
+ inline int n_active_tasks() {
+ return scheduler().n_active_tasks();
+ }
+
+ inline void notify_new_task(const Stream& stream) {
+ scheduler().notify_new_task(stream);
+ }
+
+ inline void notify_task_completion(const Stream& stream) {
+ scheduler().notify_task_completion(stream);
+ }
+
+ inline void wait_for_one() {
+ scheduler().wait_for_one();
+ }
+
+ } // namespace mlx::core::scheduler
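`StreamThread` is a textbook single-worker task queue: producers push closures under a mutex, the worker blocks on a condition variable, and the destructor sets `stop`, wakes the worker, and joins. Note that the loop only exits once the queue is empty, so pending work is never dropped. A distilled, MLX-free sketch of the same pattern (all names here are illustrative, not from the header):

    #include <condition_variable>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    struct Worker {
      std::mutex mtx;
      std::condition_variable cond;
      std::queue<std::function<void()>> q;
      bool stop = false;
      // Declared last so the other members are initialized before it starts.
      std::thread thread{[this] { run(); }};

      void run() {
        while (true) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx);
            cond.wait(lk, [this] { return !q.empty() || stop; });
            if (q.empty() && stop) return; // drain the queue before exiting
            task = std::move(q.front());
            q.pop();
          }
          task(); // run outside the lock, as in the header above
        }
      }

      void enqueue(std::function<void()> f) {
        {
          std::unique_lock<std::mutex> lk(mtx);
          q.push(std::move(f));
        }
        cond.notify_one();
      }

      ~Worker() {
        {
          std::unique_lock<std::mutex> lk(mtx);
          stop = true;
        }
        cond.notify_one();
        thread.join();
      }
    };

    int main() {
      Worker w;
      w.enqueue([] { std::cout << "task ran\n"; });
    } // destructor drains remaining tasks, then joins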
lib/python3.11/site-packages/mlx/include/mlx/stream.h ADDED
@@ -0,0 +1,32 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/device.h"
+
+ namespace mlx::core {
+
+ struct Stream {
+ int index;
+ Device device;
+ explicit Stream(int index, Device device) : index(index), device(device) {}
+ };
+
+ /** Get the default stream for the given device. */
+ Stream default_stream(Device d);
+
+ /** Make the stream the default for its device. */
+ void set_default_stream(Stream s);
+
+ /** Make a new stream on the given device. */
+ Stream new_stream(Device d);
+
+ inline bool operator==(const Stream& lhs, const Stream& rhs) {
+ return lhs.index == rhs.index;
+ }
+
+ inline bool operator!=(const Stream& lhs, const Stream& rhs) {
+ return !(lhs == rhs);
+ }
+
+ } // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/transforms.h ADDED
@@ -0,0 +1,187 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "array.h"
+
+ namespace mlx::core {
+
+ /** Fuse equivalent arrays to avoid duplicate execution. */
+ void simplify(const std::vector<array>& outputs);
+
+ template <typename... Arrays>
+ void simplify(Arrays... outputs) {
+ simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
+ }
+
+ void eval(const std::vector<array>& outputs, bool retain_graph = false);
+
+ template <typename... Arrays>
+ void eval(Arrays... outputs) {
+ eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
+ }
+
+ /**
+ * Computes the output and vector-Jacobian product (VJP) of a function.
+ *
+ * Computes the vector-Jacobian product of the vector of cotangents with the
+ * Jacobian of the function evaluated at the primals. Returns a pair of
+ * vectors of output arrays and VJP arrays.
+ **/
+ std::pair<std::vector<array>, std::vector<array>> vjp(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ const std::vector<array>& primals,
+ const std::vector<array>& cotangents);
+
+ /**
+ * Computes the output and vector-Jacobian product (VJP) of a unary function.
+ */
+ std::pair<array, array> vjp(
+ const std::function<array(const array&)>& fun,
+ const array& primal,
+ const array& cotangent);
+
+ /**
+ * Computes the output and Jacobian-vector product (JVP) of a function.
+ *
+ * Computes the Jacobian-vector product of the Jacobian of the function
+ * evaluated at the primals with the vector of tangents. Returns a pair of
+ * vectors of output arrays and JVP arrays.
+ **/
+ std::pair<std::vector<array>, std::vector<array>> jvp(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ const std::vector<array>& primals,
+ const std::vector<array>& tangents);
+
+ /**
+ * Computes the output and Jacobian-vector product (JVP) of a unary function.
+ */
+ std::pair<array, array> jvp(
+ const std::function<array(const array&)>& fun,
+ const array& primal,
+ const array& tangent);
+
+ // Return type of general value_and_grad: a function which takes an input
+ // vector of arrays and returns a pair of vectors of arrays, one for the
+ // values and one for the gradients wrt the first value.
+ using ValueAndGradFn =
+ std::function<std::pair<std::vector<array>, std::vector<array>>(
+ const std::vector<array>&)>;
+ using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
+ const std::vector<array>&)>;
+
+ /**
+ * Returns a function which computes the value and gradient of the input
+ * function with respect to a vector of input arrays.
+ **/
+ ValueAndGradFn value_and_grad(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ const std::vector<int>& argnums);
+
+ /**
+ * Returns a function which computes the value and gradient of the input
+ * function with respect to a single input array.
+ **/
+ ValueAndGradFn inline value_and_grad(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ int argnum = 0) {
+ return value_and_grad(fun, std::vector<int>{argnum});
+ }
+
+ /**
+ * Returns a function which computes the value and gradient of the unary
+ * input function.
+ **/
+ std::function<std::pair<array, array>(const array&)> inline value_and_grad(
+ const std::function<array(const array&)>& fun) {
+ return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
+ }
+
+ SimpleValueAndGradFn inline value_and_grad(
+ const std::function<array(const std::vector<array>&)>& fun,
+ const std::vector<int>& argnums) {
+ return [fun, argnums](auto inputs) {
+ auto result = value_and_grad(
+ [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
+ argnums)(inputs);
+
+ return std::make_pair(result.first[0], result.second);
+ };
+ }
+
+ SimpleValueAndGradFn inline value_and_grad(
+ const std::function<array(const std::vector<array>&)>& fun,
+ int argnum = 0) {
+ return value_and_grad(fun, std::vector<int>{argnum});
+ }
+
+ /**
+ * Returns a function which computes the gradient of the input function with
+ * respect to a vector of input arrays.
+ *
+ * The function being differentiated takes a vector of arrays and returns an
+ * array. The vector of `argnums` specifies which arguments to compute
+ * the gradient with respect to. At least one argument must be specified.
+ **/
+ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
+ const std::function<array(const std::vector<array>&)>& fun,
+ const std::vector<int>& argnums) {
+ auto fn = value_and_grad(fun, argnums);
+ return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
+ }
+
+ /**
+ * Returns a function which computes the gradient of the input function with
+ * respect to a single input array.
+ *
+ * The function being differentiated takes a vector of arrays and returns an
+ * array. The optional `argnum` index specifies which argument to compute
+ * the gradient with respect to and defaults to 0.
+ **/
+ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
+ const std::function<array(const std::vector<array>&)>& fun,
+ int argnum = 0) {
+ return grad(fun, std::vector<int>{argnum});
+ }
+
+ /**
+ * Returns a function which computes the gradient of the unary input function.
+ **/
+ std::function<array(const array&)> inline grad(
+ const std::function<array(const array&)>& fun) {
+ auto fn = value_and_grad(fun);
+ return [fn](const array& input) { return fn(input).second; };
+ }
+
+ /**
+ * Automatically vectorize a unary function over the requested axes.
+ */
+ std::function<array(const array&)> vmap(
+ const std::function<array(const array&)>& fun,
+ int in_axis = 0,
+ int out_axis = 0);
+
+ /**
+ * Automatically vectorize a binary function over the requested axes.
+ */
+ std::function<array(const array&, const array&)> vmap(
+ const std::function<array(const array&, const array&)>& fun,
+ int in_axis_a = 0,
+ int in_axis_b = 0,
+ int out_axis = 0);
+
+ /**
+ * Automatically vectorize a function over the requested axes.
+ *
+ * The input function to `vmap` takes as an argument a vector of arrays and
+ * returns a vector of arrays. Optionally specify the axes to vectorize over
+ * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
+ * Returns a vectorized function with the same signature as the input
+ * function.
+ */
+ std::function<std::vector<array>(const std::vector<array>&)> vmap(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ const std::vector<int>& in_axes = {},
+ const std::vector<int>& out_axes = {});
+
+ } // namespace mlx::core
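These transforms compose: `grad` is `value_and_grad` with the value dropped, and the unary `value_and_grad` is itself a thin wrapper that seeds `vjp` with a cotangent of 1. A minimal differentiation sketch, assuming `square` from `mlx/ops.h` and the same build setup as the earlier sketches:

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      auto f = [](const array& x) { return square(x); }; // f(x) = x^2
      std::cout << grad(f)(array(3.0f)) << std::endl;    // f'(3) = 6
      auto [y, dy] = value_and_grad(f)(array(3.0f));     // 9 and 6 in one pass
      std::cout << y << " " << dy << std::endl;
    }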
lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h ADDED
@@ -0,0 +1,17 @@
+ // Copyright © 2023 Apple Inc.
+
+ namespace mlx::core::detail {
+
+ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+ const std::vector<array>& inputs,
+ const std::vector<int>& in_axes);
+
+ std::vector<array> vmap_replace(
+ const std::vector<array>& inputs,
+ const std::vector<array>& s_inputs,
+ const std::vector<array>& s_outputs,
+ const std::vector<int>& in_axes,
+ const std::vector<int>& out_axes);
+
+ } // namespace mlx::core::detail
lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h ADDED
@@ -0,0 +1,187 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <algorithm>
+ #include <cmath>
+ #include <cstdint>
+ #include <vector>
+
+ #define __MLX_BFLOAT_NAN__ 0x7FC0
+
+ namespace mlx::core {
+
+ namespace {
+ union float_bits_bf16 {
+ float f;
+ uint32_t u;
+ };
+ } // namespace
+
+ struct _MLX_BFloat16 {
+ uint16_t bits_;
+
+ // Default constructor
+ _MLX_BFloat16() = default;
+
+ // Default copy constructor
+ _MLX_BFloat16(_MLX_BFloat16 const&) = default;
+
+ // Appease std::vector<bool> for being special
+ _MLX_BFloat16& operator=(std::vector<bool>::reference x) {
+ bits_ = x;
+ return *this;
+ }
+
+ _MLX_BFloat16& operator=(const float& x) {
+ return (*this = _MLX_BFloat16(x));
+ }
+
+ // From float32
+ _MLX_BFloat16(const float& x) {
+ if (std::isnan(x)) {
+ bits_ = __MLX_BFLOAT_NAN__;
+ } else {
+ // Union
+ float_bits_bf16 in;
+
+ // Take bits
+ in.f = x;
+
+ // Round to nearest even
+ in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
+
+ // Take upper 16 bits
+ bits_ = in.u >> 16;
+ }
+ }
+
+ // To float32
+ operator float() const {
+ // Union
+ float_bits_bf16 out;
+
+ // Upper 16 bits are the data and lower 16 bits are 0s
+ out.u = ((uint32_t)bits_) << 16;
+
+ return out.f;
+ }
+ };
+
+ #define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+ inline otype __operator__(atype lhs, btype rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ }
+
+ #define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
+ inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ } \
+ inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ }
+
+ // Operators
+ #define bfloat_binop(_op_, _operator_) \
+ bfloat_binop_base( \
+ _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
+ bfloat_binop_helper(_op_, _operator_, float, float, float); \
+ bfloat_binop_helper(_op_, _operator_, double, double, double); \
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
+
+ bfloat_binop(+, operator+);
+ bfloat_binop(-, operator-);
+ bfloat_binop(*, operator*);
+ bfloat_binop(/, operator/);
+
+ #undef bfloat_binop
+
+ // Comparison ops
+ #define bfloat_compop(__op__, __operator__) \
+ bfloat_binop_base( \
+ __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
+ bfloat_binop_helper(__op__, __operator__, bool, float, float); \
+ bfloat_binop_helper(__op__, __operator__, bool, double, double); \
+ bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
+ bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+ bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
+ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+ bfloat_compop(>, operator>);
+ bfloat_compop(<, operator<);
+ bfloat_compop(>=, operator>=);
+ bfloat_compop(<=, operator<=);
+ bfloat_compop(==, operator==);
+ bfloat_compop(!=, operator!=);
+
+ #undef bfloat_compop
+
+ // Negative
+ inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
+ return -static_cast<float>(lhs);
+ }
+
+ // Inplace ops
+ #define bfloat_inplace_op(__op__, __operator__) \
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
+ lhs = lhs __op__ rhs; \
+ return lhs; \
+ } \
+ inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \
+ lhs = lhs __op__ rhs; \
+ return lhs; \
+ }
+
+ bfloat_inplace_op(+, operator+=);
+ bfloat_inplace_op(-, operator-=);
+ bfloat_inplace_op(*, operator*=);
+ bfloat_inplace_op(/, operator/=);
+
+ #undef bfloat_inplace_op
+
+ // Bitwise ops
+
+ #define bfloat_bitop(__op__, __operator__) \
+ inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
+ _MLX_BFloat16 out; \
+ out.bits_ = lhs.bits_ __op__ rhs.bits_; \
+ return out; \
+ } \
+ inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \
+ _MLX_BFloat16 out; \
+ out.bits_ = lhs.bits_ __op__ rhs; \
+ return out; \
+ } \
+ inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \
+ _MLX_BFloat16 out; \
+ out.bits_ = lhs __op__ rhs.bits_; \
+ return out; \
+ }
+
+ bfloat_bitop(|, operator|);
+ bfloat_bitop(&, operator&);
+ bfloat_bitop(^, operator^);
+
+ #undef bfloat_bitop
+
+ #define bfloat_inplace_bitop(__op__, __operator__) \
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
+ lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
+ return lhs; \
+ } \
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \
+ lhs.bits_ = lhs.bits_ __op__ rhs; \
+ return lhs; \
+ }
+
+ bfloat_inplace_bitop(|, operator|=);
+ bfloat_inplace_bitop(&, operator&=);
+ bfloat_inplace_bitop(^, operator^=);
+
+ #undef bfloat_inplace_bitop
+
+ } // namespace mlx::core
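The narrowing constructor above is the standard round-to-nearest-even truncation: keep the top 16 bits of the IEEE-754 single, after adding `0x7FFF` plus the lowest surviving bit so ties round toward an even result. A self-contained sketch of just that step, using `memcpy` in place of the union and ignoring the NaN case handled separately above:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // float -> bfloat16 with round-to-nearest-even, mirroring the logic above.
    static uint16_t float_to_bf16(float f) {
      uint32_t u;
      std::memcpy(&u, &f, sizeof u);
      u += ((u >> 16) & 1) + 0x7FFFu; // tie breaks toward the even result
      return static_cast<uint16_t>(u >> 16);
    }

    int main() {
      std::printf("%04x\n", float_to_bf16(1.0f));  // 3f80, top half of 0x3F800000
      std::printf("%04x\n", float_to_bf16(-2.0f)); // c000
    }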
lib/python3.11/site-packages/mlx/include/mlx/types/complex.h ADDED
@@ -0,0 +1,77 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+ #include <complex>
+ #include "mlx/types/half_types.h"
+
+ namespace mlx::core {
+
+ struct complex64_t;
+
+ template <typename T>
+ static constexpr bool can_convert_to_complex64 =
+ !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
+
+ struct complex64_t : public std::complex<float> {
+ complex64_t(float v, float u) : std::complex<float>(v, u){};
+ complex64_t(std::complex<float> v) : std::complex<float>(v){};
+
+ template <
+ typename T,
+ typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
+ complex64_t(T x) : std::complex<float>(x){};
+
+ operator float() const {
+ return real();
+ };
+ };
+
+ inline bool operator>=(const complex64_t& a, const complex64_t& b) {
+ return (a.real() > b.real()) ||
+ (a.real() == b.real() && a.imag() >= b.imag());
+ }
+
+ inline bool operator>(const complex64_t& a, const complex64_t& b) {
+ return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
+ }
+
+ inline bool operator<=(const complex64_t& a, const complex64_t& b) {
+ return operator>=(b, a);
+ }
+
+ inline bool operator<(const complex64_t& a, const complex64_t& b) {
+ return operator>(b, a);
+ }
+
+ inline complex64_t operator-(const complex64_t& v) {
+ return -static_cast<std::complex<float>>(v);
+ }
+
+ // clang-format off
+ #define complex_binop_helper(_op_, _operator_, itype) \
+ inline complex64_t _operator_(itype x, const complex64_t& y) { \
+ return x _op_ static_cast<std::complex<float>>(y); \
+ } \
+ inline complex64_t _operator_(const complex64_t& x, itype y) { \
+ return static_cast<std::complex<float>>(x) _op_ y; \
+ }
+
+ #define complex_binop(_op_, _operator_) \
+ inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
+ return static_cast<std::complex<float>>(x) \
+ _op_ static_cast<std::complex<float>>(y); \
+ } \
+ complex_binop_helper(_op_, _operator_, bool) \
+ complex_binop_helper(_op_, _operator_, uint32_t) \
+ complex_binop_helper(_op_, _operator_, uint64_t) \
+ complex_binop_helper(_op_, _operator_, int32_t) \
+ complex_binop_helper(_op_, _operator_, int64_t) \
+ complex_binop_helper(_op_, _operator_, float16_t) \
+ complex_binop_helper(_op_, _operator_, bfloat16_t) \
+ complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
+ complex_binop_helper(_op_, _operator_, float)
+ // clang-format on
+
+ complex_binop(+, operator+)
+
+ } // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h ADDED
@@ -0,0 +1,234 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <algorithm>
+ #include <cmath>
+ #include <cstdint>
+ #include <vector>
+
+ #define __MLX_HALF_NAN__ 0x7D00
+
+ namespace mlx::core {
+
+ namespace {
+ union float_bits_fp16 {
+ float f;
+ uint32_t u;
+ };
+ } // namespace
+
+ struct _MLX_Float16 {
+ uint16_t bits_;
+
+ // Default constructor
+ _MLX_Float16() = default;
+
+ // Default copy constructor
+ _MLX_Float16(_MLX_Float16 const&) = default;
+
+ // Appease std::vector<bool> for being special
+ _MLX_Float16& operator=(std::vector<bool>::reference x) {
+ bits_ = x;
+ return *this;
+ }
+
+ _MLX_Float16& operator=(const float& x) {
+ return (*this = _MLX_Float16(x));
+ }
+
+ // From float32
+ _MLX_Float16(const float& x) : bits_(0) {
+ // Conversion following
+ // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
+
+ // Union
+ float_bits_fp16 in;
+
+ // Take fp32 bits
+ in.f = x;
+
+ // Find and take sign bit
+ uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
+ uint16_t x_sign_16 = (x_sign_32 >> 16);
+
+ if (std::isnan(x)) {
+ bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
+ } else {
+ // Union
+ float_bits_fp16 inf_scale, zero_scale, magic_bits;
+
+ // Find exponent bits and take the max supported by half
+ uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
+ uint32_t max_expo_32 = uint32_t(0x38800000);
+ x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
+ x_expo_32 += uint32_t(15) << 23;
+
+ // Handle scaling to inf as needed
+ inf_scale.u = uint32_t(0x77800000);
+ zero_scale.u = uint32_t(0x08800000);
+
+ // Combine with magic and let addition do rounding
+ magic_bits.u = x_expo_32;
+ magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
+
+ // Take the lower 5 bits of the exponent
+ uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
+
+ // Collect the lower 12 bits which have the mantissa
+ uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
+
+ // Combine sign, exp and mantissa
+ bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
+ }
+ }
+
+ // To float32
+ operator float() const {
+ // Conversion following
+ // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
+
+ // Union
+ float_bits_fp16 out;
+
+ uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
+ uint32_t base = (bits_ << 16);
+ uint32_t two_base = base + base;
+
+ uint32_t denorm_max = 1u << 27;
+ if (two_base < denorm_max) {
+ out.u = uint32_t(126) << 23; // magic mask
+ out.u |= (two_base >> 17); // Bits from fp16
+ out.f -= 0.5f; // magic bias
+ } else {
+ out.u = uint32_t(0xE0) << 23; // exponent offset
+ out.u += (two_base >> 4); // Bits from fp16
+ float out_unscaled = out.f; // Store value
+ out.u = uint32_t(0x7800000); // exponent scale
+ out.f *= out_unscaled;
+ }
+
+ // Add sign
+ out.u |= x_sign_32;
+
+ return out.f;
+ }
+ };
+
+ #define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+ inline otype __operator__(atype lhs, btype rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ }
+
+ #define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
+ inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ } \
+ inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
+ }
+
+ // Operators
+ #define half_binop(__op__, __operator__) \
+ half_binop_base( \
+ __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
+ half_binop_helper(__op__, __operator__, float, float, float); \
+ half_binop_helper(__op__, __operator__, double, double, double); \
+ half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
+ half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
+ half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
+ half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
+ half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
+
+ half_binop(+, operator+);
+ half_binop(-, operator-);
+ half_binop(*, operator*);
+ half_binop(/, operator/);
+
+ #undef half_binop
+
+ // Comparison ops
+ #define half_compop(__op__, __operator__) \
+ half_binop_base( \
+ __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
+ half_binop_helper(__op__, __operator__, bool, float, float); \
+ half_binop_helper(__op__, __operator__, bool, double, double); \
+ half_binop_helper(__op__, __operator__, bool, int32_t, float); \
+ half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
+ half_binop_helper(__op__, __operator__, bool, int64_t, float); \
+ half_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+ half_compop(>, operator>);
+ half_compop(<, operator<);
+ half_compop(>=, operator>=);
+ half_compop(<=, operator<=);
+ half_compop(==, operator==);
+ half_compop(!=, operator!=);
+
+ #undef half_compop
+
+ // Negative
+ inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
+ return -static_cast<float>(lhs);
+ }
+
+ // Inplace ops
+ #define half_inplace_op(__op__, __operator__) \
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
+ lhs = lhs __op__ rhs; \
+ return lhs; \
+ } \
+ inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
+ lhs = lhs __op__ rhs; \
+ return lhs; \
+ }
+
+ half_inplace_op(+, operator+=);
+ half_inplace_op(-, operator-=);
+ half_inplace_op(*, operator*=);
+ half_inplace_op(/, operator/=);
+
+ #undef half_inplace_op
+
+ // Bitwise ops
+
+ #define half_bitop(__op__, __operator__) \
+ inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
+ _MLX_Float16 out; \
+ out.bits_ = lhs.bits_ __op__ rhs.bits_; \
+ return out; \
+ } \
+ inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
+ _MLX_Float16 out; \
+ out.bits_ = lhs.bits_ __op__ rhs; \
+ return out; \
+ } \
+ inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
+ _MLX_Float16 out; \
+ out.bits_ = lhs __op__ rhs.bits_; \
+ return out; \
+ }
+
+ half_bitop(|, operator|);
+ half_bitop(&, operator&);
+ half_bitop(^, operator^);
+
+ #undef half_bitop
+
+ #define half_inplace_bitop(__op__, __operator__) \
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
+ lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
+ return lhs; \
+ } \
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
+ lhs.bits_ = lhs.bits_ __op__ rhs; \
+ return lhs; \
+ }
+
+ half_inplace_bitop(|, operator|=);
+ half_inplace_bitop(&, operator&=);
+ half_inplace_bitop(^, operator^=);
+
+ #undef half_inplace_bitop
+
+ } // namespace mlx::core
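This software path only matters when the compiler has no native half type; `half_types.h` below selects it as the fallback for `float16_t`. A minimal round-trip sketch, assuming the package's `include` directory is on the include path; 3.140625 fits exactly in fp16's 10-bit mantissa, so it should come back unchanged:

    #include <cstdio>
    #include "mlx/types/fp16.h"

    int main() {
      mlx::core::_MLX_Float16 h = 3.140625f; // exactly representable in fp16
      float back = static_cast<float>(h);    // decode path shown above
      std::printf("%f\n", back);             // 3.140625
    }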
lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h ADDED
@@ -0,0 +1,56 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+ #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+
+ #include <arm_fp16.h>
+ namespace mlx::core {
+ typedef __fp16 float16_t;
+ } // namespace mlx::core
+
+ #else
+
+ #define ADD_HALF_BINOPS
+ #include "mlx/types/fp16.h"
+ namespace mlx::core {
+ typedef struct _MLX_Float16 float16_t;
+ } // namespace mlx::core
+
+ #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+ #ifdef __ARM_FEATURE_BF16
+
+ #include <arm_bf16.h>
+ namespace mlx::core {
+ typedef __bf16 bfloat16_t;
+ } // namespace mlx::core
+
+ #else
+
+ #define ADD_HALF_BINOPS
+ #include "mlx/types/bf16.h"
+ namespace mlx::core {
+ typedef struct _MLX_BFloat16 bfloat16_t;
+ } // namespace mlx::core
+
+ #endif // __ARM_FEATURE_BF16
+
+ #ifdef ADD_HALF_BINOPS
+ namespace mlx::core {
+
+ // clang-format off
+ #define fp16_bf16_binop_helper(__op__, __operator__) \
+ inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
+ return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+ } \
+ inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
+ return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+ }
+
+ fp16_bf16_binop_helper(+, operator+)
+ fp16_bf16_binop_helper(-, operator-)
+ fp16_bf16_binop_helper(*, operator*)
+ fp16_bf16_binop_helper(/, operator/)
+ // clang-format on
+
+ } // namespace mlx::core
+ #endif
lib/python3.11/site-packages/mlx/include/mlx/utils.h ADDED
@@ -0,0 +1,44 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "array.h"
+ #include "device.h"
+ #include "dtype.h"
+ #include "stream.h"
+
+ namespace mlx::core {
+
+ /** The type from promoting the arrays' types with one another. */
+ Dtype result_type(const std::vector<array>& arrays);
+
+ std::vector<int> broadcast_shapes(
+ const std::vector<int>& s1,
+ const std::vector<int>& s2);
+
+ bool is_same_shape(const std::vector<array>& arrays);
+
+ /**
+ * Returns the axis normalized to be in the range [0, ndim).
+ * Based on numpy's normalize_axis_index. See
+ * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
+ */
+ int normalize_axis(int axis, int ndim);
+
+ std::ostream& operator<<(std::ostream& os, const Device& d);
+ std::ostream& operator<<(std::ostream& os, const Stream& s);
+ std::ostream& operator<<(std::ostream& os, const Dtype& d);
+ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
+ std::ostream& operator<<(std::ostream& os, array a);
+ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
+ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
+ inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
+ return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
+ }
+ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
+ return os << static_cast<float>(v);
+ }
+ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
+ return os << static_cast<float>(v);
+ }
+ } // namespace mlx::core
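`broadcast_shapes` applies the usual NumPy rules: align the two shapes from the right, and each dimension pair must be equal or contain a 1. A small sketch under the same build assumptions as the earlier ones:

    #include <iostream>
    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // (3, 1) and (4,) align from the right and broadcast to (3, 4).
      auto shape = broadcast_shapes({3, 1}, {4});
      std::cout << shape << std::endl; // printed via the vector<int> overload above
    }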
lib/python3.11/site-packages/mlx/lib/libmlx.dylib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8abefe46a1f39c92b28464814f05a730fa9899b17757703403c6ef362f06ac93
+ size 12420704
lib/python3.11/site-packages/mlx/lib/mlx.metallib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2eedf41000ed270283da11d889bb101aa4c88c6f8f0ec68fe6b040a5be424501
+ size 59495531
lib/python3.11/site-packages/mlx/nn/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright © 2023 Apple Inc.
+
+ from mlx.nn import losses
+ from mlx.nn.layers import *
+ from mlx.nn.utils import value_and_grad
lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (381 Bytes). View file
lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.78 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__init__.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright © 2023 Apple Inc.
+
+ from mlx.nn.layers.activations import (
+ CELU,
+ ELU,
+ GELU,
+ SELU,
+ Hardswish,
+ LeakyReLU,
+ LogSigmoid,
+ LogSoftmax,
+ Mish,
+ PReLU,
+ ReLU,
+ ReLU6,
+ SiLU,
+ Softmax,
+ Softplus,
+ Softsign,
+ Step,
+ Tanh,
+ celu,
+ elu,
+ gelu,
+ gelu_approx,
+ gelu_fast_approx,
+ hardswish,
+ leaky_relu,
+ log_sigmoid,
+ log_softmax,
+ mish,
+ prelu,
+ relu,
+ relu6,
+ selu,
+ silu,
+ softmax,
+ softplus,
+ softsign,
+ step,
+ tanh,
+ )
+ from mlx.nn.layers.base import Module
+ from mlx.nn.layers.containers import Sequential
+ from mlx.nn.layers.convolution import Conv1d, Conv2d
+ from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
+ from mlx.nn.layers.embedding import Embedding
+ from mlx.nn.layers.linear import Bilinear, Identity, Linear
+ from mlx.nn.layers.normalization import (
+ BatchNorm,
+ GroupNorm,
+ InstanceNorm,
+ LayerNorm,
+ RMSNorm,
+ )
+ from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
+ from mlx.nn.layers.quantized import QuantizedLinear
+ from mlx.nn.layers.transformer import (
+ MultiHeadAttention,
+ Transformer,
+ TransformerEncoder,
+ TransformerEncoderLayer,
+ )
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.52 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc ADDED
Binary file (28.1 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc ADDED
Binary file (1.47 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc ADDED
Binary file (6.36 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc ADDED
Binary file (6.71 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc ADDED
Binary file (2.07 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc ADDED
Binary file (6.92 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc ADDED
Binary file (17.7 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc ADDED
Binary file (6.34 kB). View file
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (18 kB). View file