18256559666a70d7f16bb789b379220578acb9d0f204be811e183cc0a99d467b
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h +16 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h +70 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h +481 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h +538 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h +176 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h +246 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h +31 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h +31 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h +370 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h +169 -0
- lib/python3.11/site-packages/mlx/include/mlx/device.h +29 -0
- lib/python3.11/site-packages/mlx/include/mlx/dtype.h +105 -0
- lib/python3.11/site-packages/mlx/include/mlx/fft.h +151 -0
- lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h +23 -0
- lib/python3.11/site-packages/mlx/include/mlx/io/load.h +114 -0
- lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h +32 -0
- lib/python3.11/site-packages/mlx/include/mlx/linalg.h +63 -0
- lib/python3.11/site-packages/mlx/include/mlx/mlx.h +14 -0
- lib/python3.11/site-packages/mlx/include/mlx/ops.h +1094 -0
- lib/python3.11/site-packages/mlx/include/mlx/primitives.h +1636 -0
- lib/python3.11/site-packages/mlx/include/mlx/random.h +193 -0
- lib/python3.11/site-packages/mlx/include/mlx/scheduler.h +173 -0
- lib/python3.11/site-packages/mlx/include/mlx/stream.h +32 -0
- lib/python3.11/site-packages/mlx/include/mlx/transforms.h +187 -0
- lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h +17 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h +187 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/complex.h +77 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h +234 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h +56 -0
- lib/python3.11/site-packages/mlx/include/mlx/utils.h +44 -0
- lib/python3.11/site-packages/mlx/lib/libmlx.dylib +3 -0
- lib/python3.11/site-packages/mlx/lib/mlx.metallib +3 -0
- lib/python3.11/site-packages/mlx/nn/__init__.py +5 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__init__.py +63 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
lib/python3.11/site-packages/llvmlite/binding/libllvmlite.dylib filter=lfs diff=lfs merge=lfs -text
|
37 |
lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
lib/python3.11/site-packages/llvmlite/binding/libllvmlite.dylib filter=lfs diff=lfs merge=lfs -text
|
37 |
lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so filter=lfs diff=lfs merge=lfs -text
|
38 |
+
lib/python3.11/site-packages/mlx/lib/libmlx.dylib filter=lfs diff=lfs merge=lfs -text
|
39 |
+
lib/python3.11/site-packages/mlx/lib/mlx.metallib filter=lfs diff=lfs merge=lfs -text
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#ifdef __METAL__
|
6 |
+
#define MTL_CONST constant
|
7 |
+
#else
|
8 |
+
#define MTL_CONST
|
9 |
+
#endif
|
10 |
+
|
11 |
+
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
12 |
+
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
|
13 |
+
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
|
14 |
+
static MTL_CONST constexpr int REDUCE_N_READS = 16;
|
15 |
+
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
|
16 |
+
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <metal_math>
|
6 |
+
|
7 |
+
/*
|
8 |
+
* Approximation to the error function.
|
9 |
+
* Based on code from:
|
10 |
+
* https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
|
11 |
+
*/
|
12 |
+
float erf(float a) {
|
13 |
+
float r, s, t, u;
|
14 |
+
t = metal::abs(a);
|
15 |
+
s = a * a;
|
16 |
+
if (t > 0.927734375f) {
|
17 |
+
// maximum error 0.99527 ulp
|
18 |
+
r = metal::fma(
|
19 |
+
-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
|
20 |
+
u = metal::fma(
|
21 |
+
-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
|
22 |
+
r = metal::fma(r, s, u);
|
23 |
+
r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
|
24 |
+
r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
|
25 |
+
r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
|
26 |
+
r = metal::fma(r, t, -t);
|
27 |
+
// TODO, replace with expm1 when implemented
|
28 |
+
r = 1.0f - metal::exp(r);
|
29 |
+
r = metal::copysign(r, a);
|
30 |
+
} else {
|
31 |
+
// maximum error 0.98929 ulp
|
32 |
+
r = -5.96761703e-4f; // -0x1.38e000p-11
|
33 |
+
r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
|
34 |
+
r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
|
35 |
+
r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
|
36 |
+
r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
|
37 |
+
r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
|
38 |
+
r = metal::fma(r, a, a);
|
39 |
+
}
|
40 |
+
return r;
|
41 |
+
}
|
42 |
+
|
43 |
+
float erfinv(float a) {
|
44 |
+
auto t = metal::fma(a, 0.0f - a, 1.0f);
|
45 |
+
t = metal::log(t);
|
46 |
+
float p;
|
47 |
+
if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
|
48 |
+
p = 3.03697567e-10f; // 0x1.4deb44p-32
|
49 |
+
p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
|
50 |
+
p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
|
51 |
+
p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
|
52 |
+
p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
|
53 |
+
p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
|
54 |
+
p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
|
55 |
+
p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
|
56 |
+
p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
|
57 |
+
} else { // maximum ulp error = 2.35002
|
58 |
+
p = 5.43877832e-9f; // 0x1.75c000p-28
|
59 |
+
p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
|
60 |
+
p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
|
61 |
+
p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
|
62 |
+
p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
|
63 |
+
p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
|
64 |
+
p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
|
65 |
+
p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
|
66 |
+
p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
|
67 |
+
p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
|
68 |
+
}
|
69 |
+
return a * p;
|
70 |
+
}
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <metal_simdgroup>
|
6 |
+
#include <metal_simdgroup_matrix>
|
7 |
+
#include <metal_stdlib>
|
8 |
+
|
9 |
+
#include "mlx/backend/metal/kernels/bf16.h"
|
10 |
+
#include "mlx/backend/metal/kernels/conv_params.h"
|
11 |
+
|
12 |
+
#define MLX_MTL_CONST static constant constexpr const
|
13 |
+
|
14 |
+
using namespace metal;
|
15 |
+
|
16 |
+
///////////////////////////////////////////////////////////////////////////////
|
17 |
+
// Loading helper
|
18 |
+
///////////////////////////////////////////////////////////////////////////////
|
19 |
+
|
20 |
+
template <
|
21 |
+
typename T,
|
22 |
+
int BM,
|
23 |
+
int BN,
|
24 |
+
int BK,
|
25 |
+
int vec_size,
|
26 |
+
int tgp_size,
|
27 |
+
int tgp_padding = 0>
|
28 |
+
struct Conv2DInputBlockLoader {
|
29 |
+
// Destination dimensions
|
30 |
+
MLX_MTL_CONST int dst_fd = BM;
|
31 |
+
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
32 |
+
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
33 |
+
|
34 |
+
// Stride along block row within the block
|
35 |
+
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
36 |
+
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
37 |
+
|
38 |
+
// Thread location indices
|
39 |
+
const short thread_idx;
|
40 |
+
const short bi;
|
41 |
+
const short bj;
|
42 |
+
|
43 |
+
// threadgroup and device memory
|
44 |
+
threadgroup T* dst;
|
45 |
+
const device T* src;
|
46 |
+
|
47 |
+
const constant MLXConvParams<2>& params;
|
48 |
+
|
49 |
+
int weight_h;
|
50 |
+
int weight_w;
|
51 |
+
|
52 |
+
int offsets_n[n_rows];
|
53 |
+
int offsets_oh[n_rows];
|
54 |
+
int offsets_ow[n_rows];
|
55 |
+
|
56 |
+
/* Constructor */
|
57 |
+
METAL_FUNC Conv2DInputBlockLoader(
|
58 |
+
const device T* src_,
|
59 |
+
threadgroup T* dst_,
|
60 |
+
const constant MLXConvParams<2>& params_,
|
61 |
+
uint3 tid [[threadgroup_position_in_grid]],
|
62 |
+
uint3 lid [[thread_position_in_threadgroup]],
|
63 |
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
64 |
+
uint simd_lane_id [[thread_index_in_simdgroup]])
|
65 |
+
: thread_idx(simd_group_id * 32 + simd_lane_id),
|
66 |
+
bi(thread_idx / n_vecs),
|
67 |
+
bj(vec_size * (thread_idx % n_vecs)),
|
68 |
+
dst(dst_ + bi * dst_ld + bj),
|
69 |
+
src(src_ + bj),
|
70 |
+
params(params_),
|
71 |
+
weight_h(0),
|
72 |
+
weight_w(0) {
|
73 |
+
int out_n_pixels = params.oS[0] * params.oS[1];
|
74 |
+
|
75 |
+
for (int i = 0; i < n_rows; ++i) {
|
76 |
+
int offset_nhw = tid.y * BM + bi + i * bstride;
|
77 |
+
offsets_n[i] = offset_nhw / out_n_pixels;
|
78 |
+
int hw = offset_nhw % out_n_pixels;
|
79 |
+
offsets_oh[i] = hw / params.oS[1];
|
80 |
+
offsets_ow[i] = hw % params.oS[1];
|
81 |
+
}
|
82 |
+
|
83 |
+
(void)lid;
|
84 |
+
}
|
85 |
+
|
86 |
+
/* Load from device memory into threadgroup memory - without bound checking */
|
87 |
+
METAL_FUNC void load_unsafe() const {
|
88 |
+
#pragma clang loop unroll(full)
|
89 |
+
for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
|
90 |
+
int n = offsets_n[i];
|
91 |
+
int oh = offsets_oh[i];
|
92 |
+
int ow = offsets_ow[i];
|
93 |
+
|
94 |
+
int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
|
95 |
+
int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
|
96 |
+
|
97 |
+
// Read from input if in bounds
|
98 |
+
if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
|
99 |
+
const device T* curr_src = src + n * params.in_strides[0] +
|
100 |
+
ih * params.in_strides[1] + iw * params.in_strides[2];
|
101 |
+
|
102 |
+
#pragma clang loop unroll(full)
|
103 |
+
for (short j = 0; j < vec_size; ++j) {
|
104 |
+
dst[is * dst_ld + j] = curr_src[j];
|
105 |
+
}
|
106 |
+
}
|
107 |
+
|
108 |
+
// Zero pad otherwise
|
109 |
+
else {
|
110 |
+
#pragma clang loop unroll(full)
|
111 |
+
for (short j = 0; j < vec_size; ++j) {
|
112 |
+
dst[is * dst_ld + j] = T(0);
|
113 |
+
}
|
114 |
+
}
|
115 |
+
}
|
116 |
+
}
|
117 |
+
|
118 |
+
/* Iteration helper */
|
119 |
+
METAL_FUNC void next() {
|
120 |
+
if (++weight_w < params.wS[1]) {
|
121 |
+
return;
|
122 |
+
}
|
123 |
+
|
124 |
+
weight_w = 0;
|
125 |
+
|
126 |
+
if (++weight_h < params.wS[0]) {
|
127 |
+
return;
|
128 |
+
}
|
129 |
+
|
130 |
+
weight_h = 0;
|
131 |
+
|
132 |
+
src += BK;
|
133 |
+
}
|
134 |
+
};
|
135 |
+
|
136 |
+
template <
|
137 |
+
typename T,
|
138 |
+
int BM,
|
139 |
+
int BN,
|
140 |
+
int BK,
|
141 |
+
int vec_size,
|
142 |
+
int tgp_size,
|
143 |
+
int tgp_padding = 0>
|
144 |
+
struct Conv2DWeightBlockLoader {
|
145 |
+
// Destination dimensions
|
146 |
+
MLX_MTL_CONST int dst_fd = BN;
|
147 |
+
MLX_MTL_CONST int dst_ld = BK + tgp_padding;
|
148 |
+
MLX_MTL_CONST int n_vecs = BK / vec_size;
|
149 |
+
|
150 |
+
// Stride along block row within the block
|
151 |
+
MLX_MTL_CONST int bstride = tgp_size / n_vecs;
|
152 |
+
MLX_MTL_CONST int n_rows = dst_fd / bstride;
|
153 |
+
|
154 |
+
// Leading dimension for src
|
155 |
+
const int src_ld;
|
156 |
+
|
157 |
+
// Thread location indices
|
158 |
+
const short thread_idx;
|
159 |
+
const short bi;
|
160 |
+
const short bj;
|
161 |
+
|
162 |
+
// threadgroup and device memory
|
163 |
+
threadgroup T* dst;
|
164 |
+
const device T* src;
|
165 |
+
|
166 |
+
const constant MLXConvParams<2>& params;
|
167 |
+
|
168 |
+
int weight_h;
|
169 |
+
int weight_w;
|
170 |
+
|
171 |
+
/* Constructor */
|
172 |
+
METAL_FUNC Conv2DWeightBlockLoader(
|
173 |
+
const device T* src_,
|
174 |
+
threadgroup T* dst_,
|
175 |
+
const constant MLXConvParams<2>& params_,
|
176 |
+
uint3 tid [[threadgroup_position_in_grid]],
|
177 |
+
uint3 lid [[thread_position_in_threadgroup]],
|
178 |
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
179 |
+
uint simd_lane_id [[thread_index_in_simdgroup]])
|
180 |
+
: src_ld(params_.wt_strides[0]),
|
181 |
+
thread_idx(simd_group_id * 32 + simd_lane_id),
|
182 |
+
bi(thread_idx / n_vecs),
|
183 |
+
bj(vec_size * (thread_idx % n_vecs)),
|
184 |
+
dst(dst_ + bi * dst_ld + bj),
|
185 |
+
src(src_ + bi * src_ld + bj),
|
186 |
+
params(params_),
|
187 |
+
weight_h(0),
|
188 |
+
weight_w(0) {
|
189 |
+
(void)lid;
|
190 |
+
(void)tid;
|
191 |
+
}
|
192 |
+
|
193 |
+
/* Load from device memory into threadgroup memory - without bound checking */
|
194 |
+
METAL_FUNC void load_unsafe() const {
|
195 |
+
const device T* curr_src =
|
196 |
+
src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
|
197 |
+
#pragma clang loop unroll(full)
|
198 |
+
for (short i = 0; i < dst_fd; i += bstride) {
|
199 |
+
#pragma clang loop unroll(full)
|
200 |
+
for (short j = 0; j < vec_size; j++) {
|
201 |
+
dst[i * dst_ld + j] = curr_src[i * src_ld + j];
|
202 |
+
}
|
203 |
+
}
|
204 |
+
}
|
205 |
+
|
206 |
+
/* Iteration helper */
|
207 |
+
METAL_FUNC void next() {
|
208 |
+
if (++weight_w < params.wS[1]) {
|
209 |
+
return;
|
210 |
+
}
|
211 |
+
|
212 |
+
weight_w = 0;
|
213 |
+
|
214 |
+
if (++weight_h < params.wS[0]) {
|
215 |
+
return;
|
216 |
+
}
|
217 |
+
|
218 |
+
weight_h = 0;
|
219 |
+
|
220 |
+
src += BK;
|
221 |
+
}
|
222 |
+
};
|
223 |
+
|
224 |
+
///////////////////////////////////////////////////////////////////////////////
|
225 |
+
// Transforms
|
226 |
+
///////////////////////////////////////////////////////////////////////////////
|
227 |
+
|
228 |
+
template <typename OutT, typename InT>
|
229 |
+
struct TransformNone {
|
230 |
+
static METAL_FUNC OutT apply(InT x) {
|
231 |
+
return static_cast<OutT>(x);
|
232 |
+
}
|
233 |
+
};
|
234 |
+
|
235 |
+
template <typename T>
|
236 |
+
struct AccumHelper {
|
237 |
+
typedef float accum_type;
|
238 |
+
};
|
239 |
+
|
240 |
+
///////////////////////////////////////////////////////////////////////////////
|
241 |
+
// MMA helper
|
242 |
+
///////////////////////////////////////////////////////////////////////////////
|
243 |
+
|
244 |
+
template <
|
245 |
+
typename T,
|
246 |
+
int BM,
|
247 |
+
int BN,
|
248 |
+
int BK,
|
249 |
+
int WM,
|
250 |
+
int WN,
|
251 |
+
bool transpose_a,
|
252 |
+
bool transpose_b,
|
253 |
+
int tgp_padding_a = 0,
|
254 |
+
int tgp_padding_b = 0,
|
255 |
+
typename AccumType = typename AccumHelper<T>::accum_type,
|
256 |
+
typename Epilogue = TransformNone<T, AccumType>>
|
257 |
+
struct Conv2DBlockMMA {
|
258 |
+
// Warp tile size along M
|
259 |
+
MLX_MTL_CONST int TM = BM / (WM * 8);
|
260 |
+
// Warp tile size along N
|
261 |
+
MLX_MTL_CONST int TN = BN / (WN * 8);
|
262 |
+
|
263 |
+
// Warp tile simdgroup matrix strides along M
|
264 |
+
MLX_MTL_CONST int TM_stride = 8 * WM;
|
265 |
+
// Warp tile simdgroup matrix strides along M
|
266 |
+
MLX_MTL_CONST int TN_stride = 8 * WN;
|
267 |
+
|
268 |
+
// Leading dimensions of threadgroup A, B blocks
|
269 |
+
MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
|
270 |
+
MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;
|
271 |
+
|
272 |
+
// Strides of A, B along reduction axis
|
273 |
+
MLX_MTL_CONST short simd_stride_a =
|
274 |
+
transpose_a ? TM_stride : TM_stride * lda_tgp;
|
275 |
+
MLX_MTL_CONST short simd_stride_b =
|
276 |
+
transpose_b ? TN_stride * ldb_tgp : TN_stride;
|
277 |
+
|
278 |
+
// Jump between elements
|
279 |
+
MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
|
280 |
+
MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;
|
281 |
+
|
282 |
+
// Offsets within threadgroup
|
283 |
+
const int tm;
|
284 |
+
const int tn;
|
285 |
+
|
286 |
+
// Simdgroup matrices
|
287 |
+
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
288 |
+
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
289 |
+
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
290 |
+
simdgroup_matrix<AccumType, 8, 8>(0)};
|
291 |
+
|
292 |
+
short sm;
|
293 |
+
short sn;
|
294 |
+
|
295 |
+
/* Constructor */
|
296 |
+
METAL_FUNC Conv2DBlockMMA(
|
297 |
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
298 |
+
uint simd_lane_id [[thread_index_in_simdgroup]])
|
299 |
+
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
300 |
+
short qid = simd_lane_id / 4;
|
301 |
+
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
302 |
+
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
303 |
+
}
|
304 |
+
|
305 |
+
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
306 |
+
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
307 |
+
// Iterate over BK in blocks of 8
|
308 |
+
#pragma clang loop unroll(full)
|
309 |
+
for (short kk = 0; kk < BK; kk += 8) {
|
310 |
+
short2 offset_a =
|
311 |
+
transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
|
312 |
+
short2 offset_b =
|
313 |
+
transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);
|
314 |
+
|
315 |
+
const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
|
316 |
+
const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;
|
317 |
+
|
318 |
+
simdgroup_barrier(mem_flags::mem_none);
|
319 |
+
// Load elements from threadgroup A as simdgroup matrices
|
320 |
+
#pragma clang loop unroll(full)
|
321 |
+
for (short i = 0; i < TM; i++) {
|
322 |
+
Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
|
323 |
+
Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
|
324 |
+
As__ += simd_stride_a;
|
325 |
+
}
|
326 |
+
|
327 |
+
simdgroup_barrier(mem_flags::mem_none);
|
328 |
+
// Load elements from threadgroup B as simdgroup matrices
|
329 |
+
#pragma clang loop unroll(full)
|
330 |
+
for (short j = 0; j < TN; j++) {
|
331 |
+
Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
|
332 |
+
Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
|
333 |
+
Bs__ += simd_stride_b;
|
334 |
+
}
|
335 |
+
|
336 |
+
simdgroup_barrier(mem_flags::mem_none);
|
337 |
+
// Multiply and accumulate into result simdgroup matrices
|
338 |
+
#pragma clang loop unroll(full)
|
339 |
+
for (short i = 0; i < TM; i++) {
|
340 |
+
#pragma clang loop unroll(full)
|
341 |
+
for (short j = 0; j < TN; j++) {
|
342 |
+
simdgroup_multiply_accumulate(
|
343 |
+
results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
|
344 |
+
}
|
345 |
+
}
|
346 |
+
}
|
347 |
+
}
|
348 |
+
|
349 |
+
/* Store results from simdgroup_matrix results into device memory */
|
350 |
+
METAL_FUNC void store_result(device T* C, const int ldc) const {
|
351 |
+
#pragma clang loop unroll(full)
|
352 |
+
for (int i = 0; i < TM; i++) {
|
353 |
+
#pragma clang loop unroll(full)
|
354 |
+
for (int j = 0; j < TN; j++) {
|
355 |
+
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
|
356 |
+
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
357 |
+
C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
|
358 |
+
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
359 |
+
}
|
360 |
+
}
|
361 |
+
}
|
362 |
+
|
363 |
+
METAL_FUNC void
|
364 |
+
store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
|
365 |
+
#pragma clang loop unroll(full)
|
366 |
+
for (int i = 0; i < TM; i++) {
|
367 |
+
if (tm + i * TM_stride + sm < dst_tile_dims.y) {
|
368 |
+
#pragma clang loop unroll(full)
|
369 |
+
for (int j = 0; j < TN; j++) {
|
370 |
+
if (tn + j * TN_stride + sn < dst_tile_dims.x) {
|
371 |
+
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
|
372 |
+
Epilogue::apply(results[i * TN + j].thread_elements()[0]);
|
373 |
+
}
|
374 |
+
|
375 |
+
if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
|
376 |
+
C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
|
377 |
+
Epilogue::apply(results[i * TN + j].thread_elements()[1]);
|
378 |
+
}
|
379 |
+
}
|
380 |
+
}
|
381 |
+
}
|
382 |
+
}
|
383 |
+
};
|
384 |
+
|
385 |
+
///////////////////////////////////////////////////////////////////////////////
|
386 |
+
// GEMM kernels
|
387 |
+
///////////////////////////////////////////////////////////////////////////////
|
388 |
+
|
389 |
+
template <
|
390 |
+
typename T,
|
391 |
+
int BM,
|
392 |
+
int BN,
|
393 |
+
int BK,
|
394 |
+
int WM,
|
395 |
+
int WN,
|
396 |
+
bool transpose_a,
|
397 |
+
bool transpose_b,
|
398 |
+
typename AccumType = typename AccumHelper<T>::accum_type,
|
399 |
+
typename Epilogue = TransformNone<T, AccumType>>
|
400 |
+
struct Conv2DImplicitGEMMKernel {
|
401 |
+
MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
|
402 |
+
MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
|
403 |
+
MLX_MTL_CONST short tgp_mem_size_a =
|
404 |
+
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
405 |
+
MLX_MTL_CONST short tgp_mem_size_b =
|
406 |
+
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
407 |
+
MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
408 |
+
|
409 |
+
MLX_MTL_CONST short tgp_size = WM * WN * 32;
|
410 |
+
MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;
|
411 |
+
|
412 |
+
using loader_a_t =
|
413 |
+
Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
|
414 |
+
using loader_b_t =
|
415 |
+
Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
|
416 |
+
using mma_t = Conv2DBlockMMA<
|
417 |
+
T,
|
418 |
+
BM,
|
419 |
+
BN,
|
420 |
+
BK,
|
421 |
+
WM,
|
422 |
+
WN,
|
423 |
+
transpose_a,
|
424 |
+
transpose_b,
|
425 |
+
tgp_padding_a,
|
426 |
+
tgp_padding_b,
|
427 |
+
AccumType,
|
428 |
+
Epilogue>;
|
429 |
+
|
430 |
+
/* Main kernel function */
|
431 |
+
static METAL_FUNC void run(
|
432 |
+
const device T* A [[buffer(0)]],
|
433 |
+
const device T* B [[buffer(1)]],
|
434 |
+
device T* C [[buffer(2)]],
|
435 |
+
const constant MLXConvParams<2>& params [[buffer(3)]],
|
436 |
+
threadgroup T* tgp_memory [[threadgroup(0)]],
|
437 |
+
uint3 tid [[threadgroup_position_in_grid]],
|
438 |
+
uint3 lid [[thread_position_in_threadgroup]],
|
439 |
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
440 |
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
441 |
+
const int c_row = tid.y * BM;
|
442 |
+
const int c_col = tid.x * BN;
|
443 |
+
const int K = params.wt_strides[0];
|
444 |
+
const int N = params.O;
|
445 |
+
|
446 |
+
B += c_col * K;
|
447 |
+
C += c_row * N + c_col;
|
448 |
+
|
449 |
+
// Prepare threadgroup memory for loading
|
450 |
+
threadgroup T* As = tgp_memory;
|
451 |
+
threadgroup T* Bs = tgp_memory + tgp_mem_size_a;
|
452 |
+
|
453 |
+
// Prepare threadgroup loading operations
|
454 |
+
loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
|
455 |
+
loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);
|
456 |
+
|
457 |
+
// Prepare threadgroup mma operation
|
458 |
+
mma_t mma_op(simd_gid, simd_lid);
|
459 |
+
|
460 |
+
for (int k = 0; k < K; k += BK) {
|
461 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
462 |
+
// Load elements into threadgroup
|
463 |
+
loader_a.load_unsafe();
|
464 |
+
loader_b.load_unsafe();
|
465 |
+
|
466 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
467 |
+
|
468 |
+
// Multiply and accumulate threadgroup elements
|
469 |
+
mma_op.mma(As, Bs);
|
470 |
+
|
471 |
+
// Prepare for next iteration
|
472 |
+
loader_a.next();
|
473 |
+
loader_b.next();
|
474 |
+
}
|
475 |
+
|
476 |
+
threadgroup_barrier(mem_flags::mem_none);
|
477 |
+
|
478 |
+
// Store results to device memory
|
479 |
+
mma_op.store_result(C, N);
|
480 |
+
}
|
481 |
+
};
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <metal_simdgroup>
|
6 |
+
#include <metal_simdgroup_matrix>
|
7 |
+
#include <metal_stdlib>
|
8 |
+
|
9 |
+
#define MLX_MTL_CONST static constant constexpr const
|
10 |
+
|
11 |
+
using namespace metal;
|
12 |
+
|
13 |
+
///////////////////////////////////////////////////////////////////////////////
|
14 |
+
// Loading helper
|
15 |
+
///////////////////////////////////////////////////////////////////////////////
|
16 |
+
|
17 |
+
// Cooperative tile loader: the whole threadgroup copies one (BROWS x BCOLS)
// tile of a matrix from device memory into threadgroup memory, optionally
// transposed. Each thread owns a `vec_size`-wide slice of a row; `next()`
// advances the source pointer one BK-step along the reduction axis.
template <
    typename T,
    int BROWS,
    int BCOLS,
    int BK,
    int vec_size,
    int tgp_size,
    bool transpose,
    bool ldK,
    int tgp_padding = 0>
struct BlockLoader {
  // Destination dimensions (rows iterated / leading dim in threadgroup mem).
  MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
  // tgp_padding pads the leading dimension (commonly used to avoid
  // threadgroup-memory bank conflicts — NOTE(review): presumed intent).
  MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
  // Number of vec_size-wide chunks per destination row.
  MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;

  // Stride along block row within the block: with tgp_size threads and
  // n_vecs chunks per row, the group covers `bstride` rows per pass.
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;

  // Leading dimension for src (device-memory layout).
  const int src_ld;
  // Stride along reduction axis between blocks. The XOR with ldK selects
  // whether a BK step moves along src's leading dimension or along the
  // contiguous axis, depending on transpose.
  const int tstride;

  // Thread location indices: flat thread id, row (bi) and column (bj) of
  // this thread's chunk inside the tile.
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory, pre-offset to this thread's chunk.
  threadgroup T* dst;
  const device T* src;

  /* Constructor: derives per-thread offsets from simdgroup/lane ids. */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tstride(
            BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking.
   * Safe only when the full tile lies inside the source matrix. */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking.
   * src_tile_dim is (columns, rows) of the valid region; out-of-bounds
   * destination entries are zero-filled so the MMA can run unconditionally. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Swap the bounds to match the (possibly transposed) iteration order.
    src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;

    // Iterate over rows of block
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
      // Row is in bounds, we check against column
      if ((bi + i) < src_tile_dim.y) {
        // Use fast thread memory for bound checks
        short tmp_idx[vec_size];
        T tmp_val[vec_size];

        // Make sure tmp_idx only contains valid indices: out-of-range
        // columns are clamped to 0 so every read below is in bounds.
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
        }

        // Read all valid indices into tmp_val
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = src[i * src_ld + tmp_idx[j]];
        }

        // Zero out unneeded values (the clamped reads above fetched real
        // data from index 0; discard it here).
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
        }

        // Copy values to threadgroup memory
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = tmp_val[j];
        }
      }

      // Row is out of bounds, we just fill tgp memory with zeros
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper: advance the source pointer one BK block along the
   * reduction axis. The threadgroup destination is reused in place. */
  METAL_FUNC void next() {
    src += tstride;
  }
};
|
130 |
+
|
131 |
+
///////////////////////////////////////////////////////////////////////////////
|
132 |
+
// Transforms
|
133 |
+
///////////////////////////////////////////////////////////////////////////////
|
134 |
+
|
135 |
+
// Identity epilogue transform: narrows/converts the accumulator value to
// the output element type without applying any other operation.
template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    OutT converted = static_cast<OutT>(x);
    return converted;
  }
};
|
141 |
+
|
142 |
+
// Maps an element type to its accumulation type. All element types
// accumulate in float (this is the default AccumType used by BlockMMA).
template <typename T>
struct AccumHelper {
  using accum_type = float;
};
|
146 |
+
|
147 |
+
///////////////////////////////////////////////////////////////////////////////
|
148 |
+
// MMA helper
|
149 |
+
///////////////////////////////////////////////////////////////////////////////
|
150 |
+
|
151 |
+
// Per-simdgroup tile multiply-accumulate. Consumes A/B tiles already staged
// in threadgroup memory (by BlockLoader) and accumulates a (TM x TN) grid of
// 8x8 simdgroup matrices in AccumType, then writes results through the
// Epilogue transform.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along N
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks (incl. padding)
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between the two elements each lane owns in an 8x8 fragment
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets of this simdgroup's warp tile within the threadgroup tile
  const int tm;
  const int tn;

  // Simdgroup matrices: A/B fragments and the running accumulators
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  // Per-lane row/column inside an 8x8 fragment (derived from the lane id;
  // matches the hardware's simdgroup-matrix element ownership layout —
  // NOTE(review): layout is implementation-defined, verify per target GPU).
  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC BlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      // Per-lane starting offsets into the (possibly transposed) tiles.
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
      // (each lane fills the two elements it owns per fragment).
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory.
   * No bound checks: C must cover the full (BM, BN) tile. */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        // Each lane writes its two owned elements through the Epilogue.
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  /* Bounds-checked variant of store_result for edge tiles;
   * dst_tile_dims is (columns, rows) of the valid output region. */
  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          // The two owned elements are checked independently since they
          // sit at adjacent columns.
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};
|
291 |
+
|
292 |
+
///////////////////////////////////////////////////////////////////////////////
|
293 |
+
// GEMM kernels
|
294 |
+
///////////////////////////////////////////////////////////////////////////////
|
295 |
+
|
296 |
+
// Tiled GEMM kernel: computes C = A x B (with optional transposes) one
// (BM, BN) output tile per threadgroup, streaming BK-wide slabs of A and B
// through threadgroup memory. MN_aligned / K_aligned select compile-time
// specializations that skip bound checks on the aligned axes.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
  // 16-byte padding (in elements) on the tiles' leading dimensions.
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  // Threadgroup memory footprints (in elements) for the A and B tiles.
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // One simdgroup (32 threads) per (WM, WN) warp-tile cell.
  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  // Wider vectorized loads for the 64x64 tiling.
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  // Loader for A iterates K along its leading axis (ldK = true),
  // loader for B along the other (ldK = false).
  using loader_a_t = BlockLoader<
      T,
      BM,
      BK,
      BK,
      vec_size,
      tgp_size,
      transpose_a,
      true,
      tgp_padding_a>;
  using loader_b_t = BlockLoader<
      T,
      BK,
      BN,
      BK,
      vec_size,
      tgp_size,
      transpose_b,
      false,
      tgp_padding_b>;
  using mma_t = BlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant int& M [[buffer(3)]],
      const constant int& N [[buffer(4)]],
      const constant int& K [[buffer(5)]],
      const constant int& batch_stride_a [[buffer(6)]],
      const constant int& batch_stride_b [[buffer(7)]],
      const constant int& batch_stride_c [[buffer(8)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Adjust for batch (grid z indexes the batch dimension)
    A += batch_stride_a * tid.z;
    B += batch_stride_b * tid.z;
    C += batch_stride_c * tid.z;

    // Adjust for transpose: device-memory leading dimensions.
    const int lda_dev = transpose_a ? M : K;
    const int ldb_dev = transpose_b ? K : N;

    // Find block in A, B, C (grid y -> output row tile, x -> column tile)
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;

    A += transpose_a ? c_row : c_row * K;
    B += transpose_b ? c_col * K : c_col;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading: A tile then B tile.
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
    loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_group_id, simd_lane_id);

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop: no bound checks anywhere.
    if (MN_aligned && K_aligned) {
      for (int k = 0; k < K; k += BK) {
        // Barrier before overwriting the tiles the previous mma read.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MN aligned, K unaligned loop: safe load only for the K tail.
    else if (MN_aligned && !K_aligned) {
      // Main loop
      int k = 0;
      for (; k + BK <= K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Loop tail: remaining K - k columns/rows, zero-padded by load_safe.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      loader_a.load_safe(short2(K - k, BM));
      loader_b.load_safe(short2(BN, K - k));

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MNK unaligned loop
    else { // Loop over K - unaligned case

      // Valid (columns, rows) of this threadgroup's output tile.
      short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));

      if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
        // Interior tile: only K needs checking.
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_unsafe();
          loader_b.load_unsafe();

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        // NOTE(review): mem_none here (vs mem_threadgroup before the tail
        // loads below) — presumably sufficient because a mem_threadgroup
        // barrier follows the loads; confirm against the Metal memory model.
        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, BM));
          loader_b.load_safe(short2(BN, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        mma_op.store_result(C, N);
        return;

      } else {
        // Edge tile: bound-checked loads every iteration.
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_safe(short2(BK, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, BK));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        threadgroup_barrier(mem_flags::mem_none);
        // Bound-checked store for the partial output tile.
        mma_op.store_result_safe(C, N, src_tile_dims);

        return;
      }
    }
  }
};
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <metal_atomic>
|
6 |
+
#include <metal_simdgroup>
|
7 |
+
|
8 |
+
#include "mlx/backend/metal/kernels/atomic.h"
|
9 |
+
#include "mlx/backend/metal/kernels/bf16.h"
|
10 |
+
#include "mlx/backend/metal/kernels/utils.h"
|
11 |
+
|
12 |
+
// Type-punning union: four packed boolean lanes viewed as one 32-bit word,
// so And/Or below can update a single lane via a uint fetch_and/fetch_or.
union bool4_or_uint {
  bool4 b;
  unsigned int i;
};
|
16 |
+
|
17 |
+
// "No-op" reduction: the output is simply overwritten with the value
// (used where the reduce kernel structure is reused for plain writes).
struct None {
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_store_explicit(out, val, offset);
  }
};
|
23 |
+
|
24 |
+
// Logical-AND reduction operator.
struct And {
  // Reduce across the simdgroup: true iff every lane's value is true.
  bool simd_reduce(bool val) {
    return simd_all(val);
  };

  // Identity element for AND.
  static constexpr constant bool init = true;

  // Update one boolean lane of a packed bool4 output word.
  // Only a false value needs a write (AND with all-true is a no-op).
  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (!val) {
      bool4_or_uint update;
      update.b = {true, true, true, true};
      update.b[elem_idx] = false;
      mlx_atomic_fetch_and_explicit(out, update.i, offset);
    }
  }

  // Scalar atomic update: storing false is idempotent, so a plain store
  // suffices; true never changes the result.
  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (!val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out &= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a && b;
  }
};
|
60 |
+
|
61 |
+
// Logical-OR reduction operator (mirror image of And).
struct Or {
  // Reduce across the simdgroup: true iff any lane's value is true.
  bool simd_reduce(bool val) {
    return simd_any(val);
  };

  // Identity element for OR.
  static constexpr constant bool init = false;

  // Update one boolean lane of a packed bool4 output word.
  // Only a true value needs a write (OR with all-false is a no-op).
  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (val) {
      bool4_or_uint update;
      update.b = {false, false, false, false};
      update.b[elem_idx] = true;
      mlx_atomic_fetch_or_explicit(out, update.i, offset);
    }
  }

  // Scalar atomic update: storing true is idempotent, so a plain store
  // suffices; false never changes the result.
  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out |= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a || b;
  }
};
|
97 |
+
|
98 |
+
// Sum reduction operator over element type U.
template <typename U>
struct Sum {
  // Reduce across the simdgroup.
  template <typename T>
  T simd_reduce(T val) {
    return simd_sum(val);
  };

  // Identity element for addition.
  static constexpr constant U init = U(0);

  // Atomically add this thread's partial into the output.
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_add_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a + b;
  }
};
|
117 |
+
|
118 |
+
// Product reduction operator over element type U.
template <typename U>
struct Prod {
  // Reduce across the simdgroup.
  template <typename T>
  T simd_reduce(T val) {
    return simd_product(val);
  };

  // Identity element for multiplication.
  static constexpr constant U init = U(1);

  // Atomically multiply this thread's partial into the output.
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_mul_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a * b;
  }
};
|
137 |
+
|
138 |
+
// Minimum reduction operator over element type U.
template <typename U>
struct Min {
  // Reduce across the simdgroup.
  template <typename T>
  T simd_reduce(T val) {
    return simd_min(val);
  };

  // Identity element: the type's +infinity / max (see Limits in utils.h).
  static constexpr constant U init = Limits<U>::max;

  // Atomically fold this thread's partial into the output.
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_min_explicit(out, val, offset);
  }

  // Operator. Note the ternary form returns b when a compares unordered
  // (NaN) with b — keep as-is rather than metal::min.
  U operator()(U a, U b) {
    return a < b ? a : b;
  }
};
|
157 |
+
|
158 |
+
// Maximum reduction operator over element type U.
template <typename U>
struct Max {
  // Reduce across the simdgroup.
  template <typename T>
  T simd_reduce(T val) {
    return simd_max(val);
  };

  // Identity element: the type's -infinity / min (see Limits in utils.h).
  static constexpr constant U init = Limits<U>::min;

  // Atomically fold this thread's partial into the output.
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_max_explicit(out, val, offset);
  }

  // Operator. Note the ternary form returns b when a compares unordered
  // (NaN) with b — keep as-is rather than metal::max.
  U operator()(U a, U b) {
    return a > b ? a : b;
  }
};
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <metal_math>
|
6 |
+
#include "mlx/backend/metal/kernels/bf16.h"
|
7 |
+
#include "mlx/backend/metal/kernels/complex.h"
|
8 |
+
|
9 |
+
///////////////////////////////////////////////////////////////////////////////
|
10 |
+
// Type limits utils
|
11 |
+
///////////////////////////////////////////////////////////////////////////////
|
12 |
+
|
13 |
+
// Numeric limits for kernel types. `max`/`min` are the reduction identities
// (+/- infinity for floats); `finite_max`/`finite_min` are the largest and
// smallest representable finite values. Specialized below per type.
template <typename U>
struct Limits {
  static const constant U max;
  static const constant U min;
  static const constant U finite_max;
  static const constant U finite_min;
};
|
20 |
+
|
21 |
+
// Integer specializations: for integers the finite bounds coincide with
// max()/min(), so all four members use the numeric_limits values.
#define instantiate_default_limit(type) \
  template <> \
  struct Limits<type> { \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_max = \
        metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_min = \
        metal::numeric_limits<type>::min(); \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
|
40 |
+
|
41 |
+
// Floating-point specializations: max/min are +/- infinity (reduction
// identities), while finite_max/finite_min are +/- the largest finite value.
#define instantiate_float_limit(type) \
  template <> \
  struct Limits<type> { \
    static constexpr constant type max = \
        metal::numeric_limits<type>::infinity(); \
    static constexpr constant type min = \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_max = \
        metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_min = \
        -metal::numeric_limits<type>::max(); \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
|
57 |
+
|
58 |
+
// Boolean specialization of Limits.
// Fix: the primary template declares finite_max/finite_min and every other
// specialization defines them; this one omitted them, so Limits<bool> would
// fail to link/instantiate for any code using the finite bounds. For bool
// the finite bounds coincide with max/min.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};
|
63 |
+
|
64 |
+
///////////////////////////////////////////////////////////////////////////////
|
65 |
+
// Indexing utils
|
66 |
+
///////////////////////////////////////////////////////////////////////////////
|
67 |
+
|
68 |
+
// Convert a flat (row-major) element index into a strided memory offset,
// iterating shape/strides from the innermost dimension outwards.
// Overload for shape/strides living in device memory.
inline size_t elem_to_loc(
    uint elem,
    device const int* shape,
    device const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

// Same computation, overloaded for shape/strides in constant memory.
inline size_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}
|
93 |
+
|
94 |
+
// Compute strided offsets into TWO arrays (a and b) for a 3-component
// thread position: elem.x/elem.y index the two innermost dimensions
// directly, elem.z is unraveled over the remaining outer dimensions.
// NDIM is known at compile time so the loop can unroll.
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t a_strides[NDIM],
    constant const size_t b_strides[NDIM]) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
      static_cast<uint>(
          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
  for (int d = NDIM - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

// Single-array variant of the above: same unraveling, one stride set.
template <int NDIM>
inline size_t elem_to_loc_nd(
    uint3 elem,
    constant const int shape[NDIM],
    constant const size_t strides[NDIM]) {
  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
  for (int d = NDIM - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}
|
126 |
+
|
127 |
+
// Fixed-rank fast paths: direct stride products, no loops. Strides are
// ordered outermost-first, thread coordinates innermost-first (x fastest).
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  return elem * stride;
}

inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  return elem.x * strides[1] + elem.y * strides[0];
}

inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
|
138 |
+
|
139 |
+
// Non templated version to handle arbitrary dims: same computation as the
// templated elem_to_loc_nd, but ndim is a runtime value (no unrolling).
inline size_t elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const size_t* strides,
    int ndim) {
  size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * strides[d];
    elem.z /= shape[d];
  }
  return loc;
}

// Runtime-ndim version of elem_to_loc_2_nd: offsets into two arrays that
// share a shape but have independent strides.
inline uint2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    int ndim) {
  uint2 loc = {
      static_cast<uint>(
          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
      static_cast<uint>(
          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}
|
172 |
+
|
173 |
+
// Flat-index variant of elem_to_loc with compile-time rank; specialized
// for ranks 1-4 below (only these are defined — no generic body).
template <int NDIM>
inline uint elem_to_loc_nd(
    uint elem,
    device const int* shape,
    device const size_t* strides);

template <>
inline uint elem_to_loc_nd<1>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  return (elem % shape[0]) * strides[0];
}

template <>
inline uint elem_to_loc_nd<2>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  // Unravel innermost-first: dimension 1 is the fastest moving.
  uint loc = (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<3>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}

template <>
inline uint elem_to_loc_nd<4>(
    uint elem,
    device const int* shape,
    device const size_t* strides) {
  uint loc = (elem % shape[3]) * strides[3];
  elem /= shape[3];
  loc += (elem % shape[2]) * strides[2];
  elem /= shape[2];
  loc += (elem % shape[1]) * strides[1];
  elem /= shape[1];
  loc += (elem % shape[0]) * strides[0];
  return loc;
}
|
225 |
+
|
226 |
+
///////////////////////////////////////////////////////////////////////////////
|
227 |
+
// Calculation utils
|
228 |
+
///////////////////////////////////////////////////////////////////////////////
|
229 |
+
|
230 |
+
/** Compute ceil((float)N/(float)M), i.e. integer ceiling division.
 *
 * Written as N / M + (N % M != 0) instead of the classic (N + M - 1) / M
 * so that the intermediate sum cannot wrap around when N is close to the
 * maximum representable size_t. Requires M > 0 (as before; M == 0 is
 * division by zero either way).
 */
inline size_t ceildiv(size_t N, size_t M) {
  return N / M + static_cast<size_t>(N % M != 0);
}
|
234 |
+
|
235 |
+
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
236 |
+
inline float log1p(float x) {
|
237 |
+
float xp1 = 1.0f + x;
|
238 |
+
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
239 |
+
}
|
240 |
+
|
241 |
+
// bfloat16 overload of log1p: promote to float, apply the same rounding
// correction as the float overload (x * log(1+x) / ((1+x) - 1)), and cast
// the result back down to bfloat16.
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast<float>(x);
  // When 1 + x rounds to exactly 1, x itself is the correct first-order
  // answer, so return it unchanged (and keep full bfloat16 precision).
  bfloat16_t ret =
      (xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
  return ret;
}
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#include <algorithm>
|
4 |
+
#include <cassert>
|
5 |
+
#include <sstream>
|
6 |
+
|
7 |
+
#include "mlx/backend/metal/copy.h"
|
8 |
+
#include "mlx/backend/metal/device.h"
|
9 |
+
#include "mlx/backend/metal/mps/gemm.h"
|
10 |
+
#include "mlx/backend/metal/utils.h"
|
11 |
+
#include "mlx/utils.h"
|
12 |
+
|
13 |
+
namespace mlx::core {
|
14 |
+
|
15 |
+
/**
 * Encode a (possibly batched) matrix multiply of `a` and `b` into `out`
 * on device `d` / stream `s`. Declaration only — semantics below are
 * inferred from the parameter names; confirm against the definition:
 *  - M, N, K: presumably the usual GEMM dimensions (out is M x N, inner
 *    dimension K) — TODO confirm.
 *  - batch_size_out: number of output matrices in the batch.
 *  - lda, ldb: leading dimensions (row strides) of a and b.
 *  - transpose_a / transpose_b: whether each operand is used transposed.
 *  - copies: out-parameter collecting temporary arrays made during
 *    dispatch, presumably so they stay alive until the command completes.
 * NOTE(review): this header has no include guard / #pragma once in view —
 * verify against the original file.
 */
void mlx_matmul(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies);
|
30 |
+
|
31 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <functional>
#include <future>
#include <memory>
#include <vector>

#include "mlx/array.h"
#include "mlx/stream.h"
|
11 |
+
|
12 |
+
namespace mlx::core::metal {
|
13 |
+
|
14 |
+
/** Compile-time backend flag: true iff this build was compiled with the
 * Metal backend (the _METAL_ macro defined). */
constexpr bool is_available() {
#ifdef _METAL_
  return true;
#else
  return false;
#endif
}

// Register a new stream with the Metal backend.
void new_stream(Stream stream);
// Returns an opaque scoped resource (presumably an NSAutoreleasePool-style
// memory pool — confirm in the .cpp) released when the pointer is dropped.
std::shared_ptr<void> new_scoped_memory_pool();

/**
 * Build the task that evaluates `arr` on its stream.
 * `deps` are futures the task waits on before running, and `p` is the
 * promise it fulfills on completion; `retain_graph` presumably controls
 * whether the computation graph inputs are kept alive — TODO confirm.
 * NOTE(review): std::function requires <functional>, which this header
 * does not include directly.
 */
std::function<void()> make_task(
    array& arr,
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p,
    bool retain_graph);
|
30 |
+
|
31 |
+
} // namespace mlx::core::metal
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <Metal/Metal.hpp>
|
6 |
+
|
7 |
+
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
|
8 |
+
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
|
9 |
+
|
10 |
+
// Register the Objective-C class symbols for the MPS types used below,
// following the metal-cpp private-symbol convention (_MTL_PRIVATE_DEF_CLS
// declares the cached `s_k<Name>` Class handle that _MPS_PRIVATE_CLS
// expands to).
namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class
|
19 |
+
|
20 |
+
// Register the Objective-C selectors used by the MPS wrappers below.
// Each entry pairs a C++ accessor name with the runtime selector string
// (the string literals are runtime data and must match the MPS API
// exactly). Note the accessor `initWithDevice_` is an abbreviation for
// the full MPSMatrixMultiplication initializer selector.
namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_rowBytes_dataType,
    "matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
    "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_,
    "initWithDevice:transposeLeft:transposeRight:"
    "resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
    "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_dataType,
    "vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
    vectorDescriptorWithLength_vectors_vectorBytes_dataType,
    "vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
    initWithDevice_transpose_rows_columns_alpha_beta,
    "initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
    encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
    "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector
|
54 |
+
|
55 |
+
// Minimal metal-cpp-style C++ bindings for the Metal Performance Shaders
// matrix/vector types. Each class mirrors the corresponding Objective-C
// MPS class, and every member function is a thin thunk that forwards to
// the ObjC runtime via Object::sendMessage with the selectors registered
// above. No behavior lives here beyond message dispatch.
namespace MPS {

// Mirrors MPSDataType: a bitfield tag where the float bit plus the bit
// width encode the element type, and the "alternate encoding" bit turns
// Float16 into BFloat16.
typedef enum DataType : uint32_t {
  DataTypeFloatBit = 0x10000000,
  DataTypeAlternateEncodingBit = 0x80000000,
  DataTypeFloat16 = DataTypeFloatBit | 16,
  DataTypeFloat32 = DataTypeFloatBit | 32,
  DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;

// Wraps MPSMatrixDescriptor: shape/stride/type metadata for a matrix (or
// a batch of matrices) stored in an MTL::Buffer.
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
 public:
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger rowBytes,
      NS::UInteger dataType);
  static class MatrixDescriptor* matrixDescriptor(
      NS::UInteger rows,
      NS::UInteger columns,
      NS::UInteger matrices,
      NS::UInteger rowBytes,
      NS::UInteger matrixBytes,
      NS::UInteger dataType);
  NS::UInteger rows() const;
};

// Wraps MPSMatrix: a matrix view over an existing MTL::Buffer.
class Matrix : public NS::Referencing<Matrix> {
 public:
  static class Matrix* alloc();
  Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
  // Convenience overload: MPS never writes through input matrices here,
  // so a const buffer is accepted and const_cast internally.
  Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};

// Wraps MPSKernel: common base exposing label and device.
class Kernel : public NS::Referencing<Kernel> {
 public:
  NS::String* label() const;
  MTL::Device* device() const;
};

// Wraps MPSMatrixMultiplication:
// result = alpha * op(left) * op(right) + beta * result.
class MatrixMultiplication
    : public NS::Referencing<MatrixMultiplication, Kernel> {
 public:
  static class MatrixMultiplication* alloc();

  MatrixMultiplication* init(
      MTL::Device* device,
      bool transposeLeft,
      bool transposeRight,
      NS::UInteger resultRows,
      NS::UInteger resultColumns,
      NS::UInteger interiorColumns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* leftMatrix,
      Matrix* rightMatrix,
      Matrix* resultMatrix);

  void setLeftMatrixOrigin(MTL::Origin origin);
  void setRightMatrixOrigin(MTL::Origin origin);
  void setResultMatrixOrigin(MTL::Origin origin);
  void setBatchStart(NS::UInteger batchStart);
  void setBatchSize(NS::UInteger batchSize);
};

// Wraps MPSVectorDescriptor: metadata for a vector (or batch of vectors).
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
 public:
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger dataType);
  static class VectorDescriptor* vectorDescriptor(
      NS::UInteger length,
      NS::UInteger vectors,
      NS::UInteger vectorBytes,
      NS::UInteger dataType);
};

// Wraps MPSVector: a vector view over an existing MTL::Buffer.
class Vector : public NS::Referencing<Vector> {
 public:
  static class Vector* alloc();
  Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
  // Const-buffer convenience overload; see Matrix::init.
  Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};

// Wraps MPSMatrixVectorMultiplication:
// result = alpha * op(matrix) * vector + beta * result.
class MatrixVectorMultiplication
    : public NS::Referencing<MatrixVectorMultiplication, Kernel> {
 public:
  static class MatrixVectorMultiplication* alloc();

  MatrixVectorMultiplication* init(
      MTL::Device* device,
      bool transpose,
      NS::UInteger rows,
      NS::UInteger columns,
      double alpha,
      double beta);

  void encodeToCommandBuffer(
      MTL::CommandBuffer* commandBuffer,
      Matrix* inputMatrix,
      Vector* inputVector,
      Vector* resultVector);
};

//------------------------------------------------------------------------
// Inline definitions: pure Objective-C message-send thunks. Class-method
// calls pass the registered Class handle; instance-method calls pass
// `this`.
//------------------------------------------------------------------------

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger rowBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
      rows,
      columns,
      rowBytes,
      dataType);
}

_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
    NS::UInteger rows,
    NS::UInteger columns,
    NS::UInteger matrices,
    NS::UInteger rowBytes,
    NS::UInteger matrixBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<MatrixDescriptor*>(
      _MPS_PRIVATE_CLS(MPSMatrixDescriptor),
      _MPS_PRIVATE_SEL(
          matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
      rows,
      columns,
      matrices,
      rowBytes,
      matrixBytes,
      dataType);
}

_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
  return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}

_MTL_INLINE Matrix* Matrix::alloc() {
  return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}

_MTL_INLINE Matrix* Matrix::init(
    MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  return Object::sendMessage<Matrix*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Matrix* Matrix::init(
    const MTL::Buffer* buffer,
    MatrixDescriptor* descriptor) {
  // const_cast is safe here only because the ObjC API lacks const
  // qualifiers; the wrapped call does not mutate the buffer contents.
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE NS::String* Kernel::label() const {
  return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}

_MTL_INLINE MTL::Device* Kernel::device() const {
  return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
  return NS::Object::alloc<MatrixMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}

_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
    MTL::Device* device,
    bool transposeLeft,
    bool transposeRight,
    NS::UInteger resultRows,
    NS::UInteger resultColumns,
    NS::UInteger interiorColumns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_),
      device,
      transposeLeft,
      transposeRight,
      resultRows,
      resultColumns,
      interiorColumns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* leftMatrix,
    Matrix* rightMatrix,
    Matrix* resultMatrix) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
      commandBuffer,
      leftMatrix,
      rightMatrix,
      resultMatrix);
}

_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
    MTL::Origin origin) {
  Object::sendMessage<void>(
      this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}

_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}

_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
  Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
      length,
      dataType);
}

_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
    NS::UInteger length,
    NS::UInteger vectors,
    NS::UInteger vectorBytes,
    NS::UInteger dataType) {
  return Object::sendMessage<VectorDescriptor*>(
      _MPS_PRIVATE_CLS(MPSVectorDescriptor),
      _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
      length,
      vectors,
      vectorBytes,
      dataType);
}

_MTL_INLINE Vector* Vector::alloc() {
  return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}

_MTL_INLINE Vector* Vector::init(
    MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return Object::sendMessage<Vector*>(
      this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}

_MTL_INLINE Vector* Vector::init(
    const MTL::Buffer* buffer,
    VectorDescriptor* descriptor) {
  return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
  return NS::Object::alloc<MatrixVectorMultiplication>(
      _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}

_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
    MTL::Device* device,
    bool transpose,
    NS::UInteger rows,
    NS::UInteger columns,
    double alpha,
    double beta) {
  return Object::sendMessage<MatrixVectorMultiplication*>(
      this,
      _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
      device,
      transpose,
      rows,
      columns,
      alpha,
      beta);
}

_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
    MTL::CommandBuffer* commandBuffer,
    Matrix* inputMatrix,
    Vector* inputVector,
    Vector* resultVector) {
  return Object::sendMessage<void>(
      this,
      _MPS_PRIVATE_SEL(
          encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
      commandBuffer,
      inputMatrix,
      inputVector,
      resultVector);
}

} // namespace MPS
|
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "mlx/array.h"
|
6 |
+
#include "mlx/backend/metal/device.h"
|
7 |
+
|
8 |
+
namespace mlx::core {
|
9 |
+
|
10 |
+
namespace {
|
11 |
+
|
12 |
+
// Bind array `a` at slot `idx` of an argument encoder. The offset is the
// byte distance between the array's data pointer and the start of its
// underlying MTL buffer (arrays may be views at a nonzero offset).
void set_array_buffer(
    MTL::ComputeCommandEncoder* compute_encoder,
    MTL::ArgumentEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
  // MTL::Resource usage through argument buffer needs to be explicitly
  // flagged to enable hazard tracking
  compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
}
|
25 |
+
|
26 |
+
// Bind array `a` directly at buffer slot `idx` of a compute encoder.
// Same offset computation as the argument-encoder overload, but no
// useResource call is needed for direct bindings.
void set_array_buffer(
    MTL::ComputeCommandEncoder* enc,
    const array& a,
    int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  // Byte offset of the array's (possibly view) data within its buffer.
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
}
|
35 |
+
|
36 |
+
// Return the textual name of the array's dtype (e.g. "float32"); note
// bool maps to "bool_". Presumably used to build type-suffixed kernel
// names — confirm with callers. The switch covers every Dtype::Val, so
// tname is always assigned.
std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {
    case bool_:
      tname = "bool_";
      break;
    case uint8:
      tname = "uint8";
      break;
    case uint16:
      tname = "uint16";
      break;
    case uint32:
      tname = "uint32";
      break;
    case uint64:
      tname = "uint64";
      break;
    case int8:
      tname = "int8";
      break;
    case int16:
      tname = "int16";
      break;
    case int32:
      tname = "int32";
      break;
    case int64:
      tname = "int64";
      break;
    case float16:
      tname = "float16";
      break;
    case float32:
      tname = "float32";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;
    case complex64:
      tname = "complex64";
      break;
  }
  return tname;
}
|
81 |
+
|
82 |
+
// Pick power-of-two threadgroup dimensions for a 3-d dispatch.
// Round-robins over the axes, doubling an axis' block size while the
// corresponding problem dimension can still fill it, and stops when the
// total threadgroup size reaches 2^10 = 1024 threads or no axis grew in
// a full round.
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0}; // log2 of the block size chosen per axis
  int sum = 0; // pows[0] + pows[1] + pows[2] == log2(total threads)
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
      pows[1]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim2 >= (1 << (pows[2] + 1))) {
      pows[2]++;
      sum++;
    }
    // Stop if nothing grew this round (all axes saturated) or we hit the
    // 1024-thread cap.
    if (sum == presum || sum == 10) {
      break;
    }
  }
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
|
112 |
+
|
113 |
+
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
//
// Two adjacent dims i-1 and i are mergeable when, for EVERY input,
// strides[i] * shape[i] == strides[i - 1] (i.e. stepping past dim i lands
// exactly where dim i-1 steps). The merged group keeps the stride of its
// innermost (last) member.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
  // Make a vector that has axes separated with -1. Collapse all axes between
  // -1.
  std::vector<int> to_collapse;
  if (xs[0].ndim() > 0) {
    to_collapse.push_back(0);
    for (int i = 1; i < xs[0].ndim(); i++) {
      bool contiguous = true;
      for (auto& x : xs) {
        if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
          contiguous = false;
        }
        if (!contiguous) {
          break;
        }
      }
      // A -1 marks a group boundary: dim i starts a new collapsed group.
      if (!contiguous) {
        to_collapse.push_back(-1);
      }
      to_collapse.push_back(i);
    }
    to_collapse.push_back(-1);
  }

  std::vector<int> out_shape;
  std::vector<std::vector<size_t>> out_strides(xs.size());
  for (int i = 0; i < to_collapse.size(); i++) {
    // Multiply together the shape of each group; the inner while advances
    // i to the group's trailing -1 sentinel.
    int current_shape = xs[0].shape()[to_collapse[i]];
    while (to_collapse[++i] != -1) {
      current_shape *= xs[0].shape()[to_collapse[i]];
    }
    out_shape.push_back(current_shape);
    for (int j = 0; j < xs.size(); j++) {
      // to_collapse[i - 1] is the group's last (innermost) axis; its
      // stride is the stride of the merged dimension.
      out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
    }
  }

  return std::make_tuple(out_shape, out_strides);
}
|
159 |
+
|
160 |
+
// Variadic convenience overload: packs the arrays into a vector and
// forwards to the vector overload above.
// NOTE(review): the pack is taken by value, so std::forward here is
// effectively std::move on copies; `Arrays&&... xs` was probably
// intended. Harmless because copying an `array` is presumably cheap —
// confirm against mlx/array.h.
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
  return collapse_contiguous_dims(
      std::vector<array>{std::forward<Arrays>(xs)...});
}
|
166 |
+
|
167 |
+
} // namespace
|
168 |
+
|
169 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/device.h
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
namespace mlx::core {
|
6 |
+
|
7 |
+
// A compute device: a type (cpu or gpu) plus an index distinguishing
// multiple devices of the same type.
struct Device {
  enum class DeviceType {
    cpu,
    gpu,
  };

  // Shorthands so callers can write Device::cpu / Device::gpu.
  static constexpr DeviceType cpu = DeviceType::cpu;
  static constexpr DeviceType gpu = DeviceType::gpu;

  Device(DeviceType type, int index = 0) : type(type), index(index){};

  DeviceType type;
  int index;
};

// The process-wide default device used when none is given explicitly.
const Device& default_device();

void set_default_device(const Device& d);

// Devices compare by their fields (defined out of line).
bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);
|
28 |
+
|
29 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/dtype.h
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <complex>
|
6 |
+
#include <cstdint>
|
7 |
+
#include <ostream>
|
8 |
+
#include <string>
|
9 |
+
|
10 |
+
#include "mlx/types/complex.h"
|
11 |
+
#include "mlx/types/half_types.h"
|
12 |
+
|
13 |
+
namespace mlx::core {
|
14 |
+
|
15 |
+
// Runtime element-type descriptor: a type tag plus the element size in
// bytes.
struct Dtype {
  // One enumerator per supported scalar type.
  enum class Val {
    bool_,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    bfloat16,
    complex64,
  };

  // Coarse type "kind", mirroring the NumPy array-protocol kind letters.
  enum class Kind {
    b, /* bool */
    u, /* unsigned int */
    i, /* signed int */
    f, /* float */
    c, /* complex */
    V, /* void - used for brain float */
  };

  Val val;
  const uint8_t size; // element size in bytes
  constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
  // Implicit conversion to the tag so a Dtype can be compared or
  // switch()ed directly against Val enumerators.
  constexpr operator Val() const {
    return val;
  };
};
|
48 |
+
|
49 |
+
// Placeholder: every dtype is available on every build.
inline bool is_available(const Dtype& dtype) {
  return true;
}

// Canonical Dtype singletons. The 16-bit float types use sizeof(uint16_t)
// as a stand-in since there is no native 16-bit float in C++ here.
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};

static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

// Type-promotion rule for binary ops (defined out of line).
Dtype promote_types(const Dtype& t1, const Dtype& t2);

// Element size in bytes.
inline uint8_t size_of(const Dtype& t) {
  return t.size;
}

// Map a dtype to its Kind letter (defined out of line).
Dtype::Kind kindof(const Dtype& t);

// Note: bool counts as unsigned here.
inline bool is_unsigned(const Dtype& t) {
  return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}

// Floats, bfloat16 (Kind::V), and complex all count as floating point.
inline bool is_floating_point(const Dtype& t) {
  return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
      kindof(t) == Dtype::Kind::c;
}

inline bool is_complex(const Dtype& t) {
  return kindof(t) == Dtype::Kind::c;
}

// Complement of is_floating_point: ints and bool.
inline bool is_integral(const Dtype& t) {
  return !(is_floating_point(t));
}

// Compile-time C++ type -> Dtype mapping; specializations are defined
// out of line.
template <typename T>
struct TypeToDtype {
  operator Dtype();
};

// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);
|
104 |
+
|
105 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/fft.h
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <variant>
|
6 |
+
|
7 |
+
#include "array.h"
|
8 |
+
#include "device.h"
|
9 |
+
#include "stream.h"
|
10 |
+
|
11 |
+
namespace mlx::core::fft {
|
12 |
+
|
13 |
+
// A stream, a device, or monostate meaning "use the default".
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;

/** Compute the n-dimensional Fourier Transform.
 * `n` gives the transform size per axis and `axes` the axes to transform
 * (presumably matching numpy.fft.fftn semantics — confirm against the
 * definition); overloads omit `n` (use the axis sizes) or both. */
array fftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
array fftn(const array& a, StreamOrDevice s = {});

/** Compute the n-dimensional inverse Fourier Transform. */
array ifftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array ifftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array ifftn(const array& a, StreamOrDevice s = {});

/** Compute the one-dimensional Fourier Transform.
 * Thin wrappers delegating to fftn over a single axis (default: last). */
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return fftn(a, {n}, {axis}, s);
}
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return fftn(a, {axis}, s);
}

/** Compute the one-dimensional inverse Fourier Transform. */
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return ifftn(a, {n}, {axis}, s);
}
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return ifftn(a, {axis}, s);
}

/** Compute the two-dimensional Fourier Transform.
 * Wrappers over fftn; default axes are the last two. */
inline array fft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return fftn(a, n, axes, s);
}
inline array fft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return fftn(a, axes, s);
}

/** Compute the two-dimensional inverse Fourier Transform. */
inline array ifft2(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {}) {
  return ifftn(a, n, axes, s);
}
inline array ifft2(
    const array& a,
    const std::vector<int>& axes = {-2, -1},
    StreamOrDevice s = {}) {
  return ifftn(a, axes, s);
}

/** Compute the n-dimensional Fourier Transform on a real input. */
array rfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array rfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array rfftn(const array& a, StreamOrDevice s = {});

/** Compute the n-dimensional inverse of `rfftn`. */
array irfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array irfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
array irfftn(const array& a, StreamOrDevice s = {});

/** Compute the one-dimensional Fourier Transform on a real input. */
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
  return rfftn(a, {n}, {axis}, s);
}
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
  return rfftn(a, {axis}, s);
}
|
113 |
+
/** Compute the one-dimensional inverse of `rfft`. */
|
114 |
+
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
115 |
+
return irfftn(a, {n}, {axis}, s);
|
116 |
+
}
|
117 |
+
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
118 |
+
return irfftn(a, {axis}, s);
|
119 |
+
}
|
120 |
+
|
121 |
+
/** Compute the two-dimensional Fourier Transform on a real input. */
|
122 |
+
inline array rfft2(
|
123 |
+
const array& a,
|
124 |
+
const std::vector<int>& n,
|
125 |
+
const std::vector<int>& axes,
|
126 |
+
StreamOrDevice s = {}) {
|
127 |
+
return rfftn(a, n, axes, s);
|
128 |
+
}
|
129 |
+
inline array rfft2(
|
130 |
+
const array& a,
|
131 |
+
const std::vector<int>& axes = {-2, -1},
|
132 |
+
StreamOrDevice s = {}) {
|
133 |
+
return rfftn(a, axes, s);
|
134 |
+
}
|
135 |
+
|
136 |
+
/** Compute the two-dimensional inverse of `rfft2`. */
|
137 |
+
inline array irfft2(
|
138 |
+
const array& a,
|
139 |
+
const std::vector<int>& n,
|
140 |
+
const std::vector<int>& axes,
|
141 |
+
StreamOrDevice s = {}) {
|
142 |
+
return irfftn(a, n, axes, s);
|
143 |
+
}
|
144 |
+
inline array irfft2(
|
145 |
+
const array& a,
|
146 |
+
const std::vector<int>& axes = {-2, -1},
|
147 |
+
StreamOrDevice s = {}) {
|
148 |
+
return irfftn(a, axes, s);
|
149 |
+
}
|
150 |
+
|
151 |
+
} // namespace mlx::core::fft
|
lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

// Write a human-readable rendering of the graph producing `outputs` to `os`.
void print_graph(std::ostream& os, const std::vector<array>& outputs);

// Variadic convenience overload: collect the outputs into a vector and
// dispatch to the vector form above.
template <typename... Arrays>
void print_graph(std::ostream& os, Arrays... outputs) {
  std::vector<array> roots{std::forward<Arrays>(outputs)...};
  print_graph(os, roots);
}

// Write the graph producing `outputs` to `os` in Graphviz DOT format.
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);

// Variadic convenience overload mirroring print_graph above.
template <typename... Arrays>
void export_to_dot(std::ostream& os, Arrays... outputs) {
  std::vector<array> roots{std::forward<Arrays>(outputs)...};
  export_to_dot(os, roots);
}

} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/io/load.h
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright © 2023 Apple Inc.

#pragma once

#include <fstream>
#include <istream>
#include <memory>

namespace mlx::core {

namespace io {

/** Abstract byte-stream source used by the array load routines. */
class Reader {
 public:
  virtual bool is_open() const = 0;
  virtual bool good() const = 0;
  virtual size_t tell() const = 0;
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  virtual void read(char* data, size_t n) = 0;
  // A short human-readable description of the source, used in error messages.
  virtual std::string label() const = 0;
  // Polymorphic base: a virtual destructor is required so that deleting a
  // derived reader through a Reader* runs the derived destructor.
  virtual ~Reader() = default;
};

/** Abstract byte-stream sink used by the array save routines. */
class Writer {
 public:
  virtual bool is_open() const = 0;
  virtual bool good() const = 0;
  virtual size_t tell() const = 0;
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  virtual void write(const char* data, size_t n) = 0;
  // A short human-readable description of the sink, used in error messages.
  virtual std::string label() const = 0;
  // Polymorphic base: virtual destructor for safe deletion through Writer*.
  virtual ~Writer() = default;
};

/** Reader backed by a std::ifstream. */
class FileReader : public Reader {
 public:
  explicit FileReader(const std::shared_ptr<std::ifstream>& is)
      : is_(is), label_("stream") {}
  explicit FileReader(const std::string& file_path)
      : is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
        label_(file_path) {}

  bool is_open() const override {
    return is_->is_open();
  }

  bool good() const override {
    return is_->good();
  }

  size_t tell() const override {
    return is_->tellg();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    is_->seekg(off, way);
  }

  void read(char* data, size_t n) override {
    is_->read(data, n);
  }

  std::string label() const override {
    return "file " + label_;
  }

 private:
  std::shared_ptr<std::ifstream> is_;
  std::string label_;
};

/** Writer backed by a std::ofstream. */
class FileWriter : public Writer {
 public:
  // Parameter renamed from the misleading `is` (it is an output stream).
  explicit FileWriter(const std::shared_ptr<std::ofstream>& os)
      : os_(os), label_("stream") {}
  explicit FileWriter(const std::string& file_path)
      : os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
        label_(file_path) {}

  bool is_open() const override {
    return os_->is_open();
  }

  bool good() const override {
    return os_->good();
  }

  size_t tell() const override {
    return os_->tellp();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    os_->seekp(off, way);
  }

  void write(const char* data, size_t n) override {
    os_->write(data, n);
  }

  std::string label() const override {
    return "file " + label_;
  }

 private:
  std::shared_ptr<std::ofstream> os_;
  std::string label_;
};

} // namespace io
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright © 2023 Apple Inc.

#pragma once

#include <json.hpp>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

namespace mlx::core {

// Safetensors dtype tags: floating-point types.
#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

// Safetensors dtype tags: boolean and integer types.
#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/linalg.h
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/stream.h"

namespace mlx::core::linalg {

/**
 * Compute vector or matrix norms.
 *
 * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
 * - If axis is not provided but ord is, then x must be either 1D or 2D.
 * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
 *   for matrices) is computed along the given axes. At most 2 axes can be
 *   specified.
 * - If both axis and ord are provided, then the corresponding matrix or vector
 *   norm is computed. At most 2 axes can be specified.
 */
array norm(
    const array& a,
    const double ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
// Single-axis convenience overload for the numeric-ord form.
inline array norm(
    const array& a,
    const double ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  std::vector<int> dims{axis};
  return norm(a, ord, dims, keepdims, s);
}
array norm(
    const array& a,
    const std::string& ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
// Single-axis convenience overload for the string-ord form (e.g. "fro").
inline array norm(
    const array& a,
    const std::string& ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  std::vector<int> dims{axis};
  return norm(a, ord, dims, keepdims, s);
}
array norm(
    const array& a,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
// Single-axis convenience overload with the default (2 / Frobenius) norm.
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
  std::vector<int> dims{axis};
  return norm(a, dims, keepdims, s);
}

} // namespace mlx::core::linalg
|
lib/python3.11/site-packages/mlx/include/mlx/mlx.h
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "mlx/array.h"
|
6 |
+
#include "mlx/backend/metal/metal.h"
|
7 |
+
#include "mlx/device.h"
|
8 |
+
#include "mlx/fft.h"
|
9 |
+
#include "mlx/linalg.h"
|
10 |
+
#include "mlx/ops.h"
|
11 |
+
#include "mlx/random.h"
|
12 |
+
#include "mlx/stream.h"
|
13 |
+
#include "mlx/transforms.h"
|
14 |
+
#include "mlx/utils.h"
|
lib/python3.11/site-packages/mlx/include/mlx/ops.h
ADDED
@@ -0,0 +1,1094 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>
#include <variant>

#include "array.h"
#include "device.h"
#include "io/load.h"
#include "stream.h"

namespace mlx::core {

// Either a concrete stream, a device (use its default stream), or nothing
// (use the global default stream).
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;

// Resolve a StreamOrDevice to a concrete Stream.
Stream to_stream(StreamOrDevice s);

/** Creation operations */

/**
 * A 1D array of numbers starting at `start` (optional),
 * stopping at stop, stepping by `step` (optional). */
array arange(
    double start,
    double stop,
    double step,
    Dtype dtype,
    StreamOrDevice s = {});
array arange(double start, double stop, double step, StreamOrDevice s = {});
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double start, double stop, StreamOrDevice s = {});
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double stop, StreamOrDevice s = {});

array arange(int start, int stop, int step, StreamOrDevice s = {});
array arange(int start, int stop, StreamOrDevice s = {});
array arange(int stop, StreamOrDevice s = {});

/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
array linspace(
    double start,
    double stop,
    int num = 50,
    Dtype dtype = float32,
    StreamOrDevice s = {});

/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});

/** Create a view of an array with the given shape and strides. */
array as_strided(
    const array& a,
    std::vector<int> shape,
    std::vector<size_t> strides,
    size_t offset,
    StreamOrDevice s = {});

/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape with the given value(s). */
array full(
    const std::vector<int>& shape,
    const array& vals,
    Dtype dtype,
    StreamOrDevice s = {});
array full(
    const std::vector<int>& shape,
    const array& vals,
    StreamOrDevice s = {});
// Scalar convenience: wrap the value in a 0-d array of the requested dtype.
template <typename T>
array full(
    const std::vector<int>& shape,
    T val,
    Dtype dtype,
    StreamOrDevice s = {}) {
  array fill_value(val, dtype);
  return full(shape, fill_value, to_stream(s));
}
// Scalar convenience: dtype is inferred from the value's C++ type.
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
  array fill_value(val);
  return full(shape, fill_value, to_stream(s));
}

/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
  return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
  return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape (n,m) with ones in the specified diagonal
 * k, and zeros everywhere else. */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
// Square matrix, main diagonal, explicit dtype.
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
  return eye(n, n, 0, dtype, s);
}
// Rectangular matrix, main diagonal, float32.
inline array eye(int n, int m, StreamOrDevice s = {}) {
  return eye(n, m, 0, float32, s);
}
// Rectangular matrix, diagonal k, float32.
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
  return eye(n, m, k, float32, s);
}
// Square matrix, main diagonal, float32.
inline array eye(int n, StreamOrDevice s = {}) {
  return eye(n, n, 0, float32, s);
}

/** Create a square matrix of shape (n,n) of zeros, and ones in the major
 * diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
inline array identity(int n, StreamOrDevice s = {}) {
  return identity(n, float32, s);
}

array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
  return tri(n, n, 0, type, s);
}

array tril(array x, int k, StreamOrDevice s = {});
array triu(array x, int k, StreamOrDevice s = {});
129 |
+
|
130 |
+
/** array manipulation */
|
131 |
+
|
132 |
+
/** Reshape an array to the given shape. */
|
133 |
+
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
|
134 |
+
|
135 |
+
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
136 |
+
array flatten(
|
137 |
+
const array& a,
|
138 |
+
int start_axis,
|
139 |
+
int end_axis = -1,
|
140 |
+
StreamOrDevice s = {});
|
141 |
+
|
142 |
+
/** Flatten the array to 1D. */
|
143 |
+
array flatten(const array& a, StreamOrDevice s = {});
|
144 |
+
|
145 |
+
/** Remove singleton dimensions at the given axes. */
|
146 |
+
array squeeze(
|
147 |
+
const array& a,
|
148 |
+
const std::vector<int>& axes,
|
149 |
+
StreamOrDevice s = {});
|
150 |
+
|
151 |
+
/** Remove singleton dimensions at the given axis. */
|
152 |
+
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
|
153 |
+
return squeeze(a, std::vector<int>{axis}, s);
|
154 |
+
}
|
155 |
+
|
156 |
+
/** Remove all singleton dimensions. */
|
157 |
+
array squeeze(const array& a, StreamOrDevice s = {});
|
158 |
+
|
159 |
+
/** Add a singleton dimension at the given axes. */
|
160 |
+
array expand_dims(
|
161 |
+
const array& a,
|
162 |
+
const std::vector<int>& axes,
|
163 |
+
StreamOrDevice s = {});
|
164 |
+
|
165 |
+
/** Add a singleton dimension at the given axis. */
|
166 |
+
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
|
167 |
+
return expand_dims(a, std::vector<int>{axis}, s);
|
168 |
+
}
|
169 |
+
|
170 |
+
/** Slice an array. */
|
171 |
+
array slice(
|
172 |
+
const array& a,
|
173 |
+
std::vector<int> start,
|
174 |
+
std::vector<int> stop,
|
175 |
+
std::vector<int> strides,
|
176 |
+
StreamOrDevice s = {});
|
177 |
+
|
178 |
+
/** Slice an array with a stride of 1 in each dimension. */
|
179 |
+
array slice(
|
180 |
+
const array& a,
|
181 |
+
const std::vector<int>& start,
|
182 |
+
const std::vector<int>& stop,
|
183 |
+
StreamOrDevice s = {});
|
184 |
+
|
185 |
+
/** Split an array into sub-arrays along a given axis. */
|
186 |
+
std::vector<array>
|
187 |
+
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
|
188 |
+
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
|
189 |
+
std::vector<array> split(
|
190 |
+
const array& a,
|
191 |
+
const std::vector<int>& indices,
|
192 |
+
int axis,
|
193 |
+
StreamOrDevice s = {});
|
194 |
+
std::vector<array>
|
195 |
+
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
196 |
+
|
197 |
+
/**
|
198 |
+
* Clip (limit) the values in an array.
|
199 |
+
*/
|
200 |
+
array clip(
|
201 |
+
const array& a,
|
202 |
+
const std::optional<array>& a_min = std::nullopt,
|
203 |
+
const std::optional<array>& a_max = std::nullopt,
|
204 |
+
StreamOrDevice s = {});
|
205 |
+
|
206 |
+
/** Concatenate arrays along a given axis. */
|
207 |
+
array concatenate(
|
208 |
+
const std::vector<array>& arrays,
|
209 |
+
int axis,
|
210 |
+
StreamOrDevice s = {});
|
211 |
+
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
|
212 |
+
|
213 |
+
/** Stack arrays along a new axis. */
|
214 |
+
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
215 |
+
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
216 |
+
|
217 |
+
/** Repeat an array along an axis. */
|
218 |
+
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
219 |
+
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
220 |
+
|
221 |
+
/** Permutes the dimensions according to the given axes. */
|
222 |
+
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
223 |
+
inline array transpose(
|
224 |
+
const array& a,
|
225 |
+
std::initializer_list<int> axes,
|
226 |
+
StreamOrDevice s = {}) {
|
227 |
+
return transpose(a, std::vector<int>(axes), s);
|
228 |
+
}
|
229 |
+
|
230 |
+
/** Swap two axes of an array. */
|
231 |
+
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
|
232 |
+
|
233 |
+
/** Move an axis of an array. */
|
234 |
+
array moveaxis(
|
235 |
+
const array& a,
|
236 |
+
int source,
|
237 |
+
int destination,
|
238 |
+
StreamOrDevice s = {});
|
239 |
+
|
240 |
+
/** Pad an array with a constant value */
|
241 |
+
array pad(
|
242 |
+
const array& a,
|
243 |
+
const std::vector<int>& axes,
|
244 |
+
const std::vector<int>& low_pad_size,
|
245 |
+
const std::vector<int>& high_pad_size,
|
246 |
+
const array& pad_value = array(0),
|
247 |
+
StreamOrDevice s = {});
|
248 |
+
|
249 |
+
/** Pad an array with a constant value along all axes */
|
250 |
+
array pad(
|
251 |
+
const array& a,
|
252 |
+
const std::vector<std::pair<int, int>>& pad_width,
|
253 |
+
const array& pad_value = array(0),
|
254 |
+
StreamOrDevice s = {});
|
255 |
+
array pad(
|
256 |
+
const array& a,
|
257 |
+
const std::pair<int, int>& pad_width,
|
258 |
+
const array& pad_value = array(0),
|
259 |
+
StreamOrDevice s = {});
|
260 |
+
array pad(
|
261 |
+
const array& a,
|
262 |
+
int pad_width,
|
263 |
+
const array& pad_value = array(0),
|
264 |
+
StreamOrDevice s = {});
|
265 |
+
|
266 |
+
/** Permutes the dimensions in reverse order. */
|
267 |
+
array transpose(const array& a, StreamOrDevice s = {});
|
268 |
+
|
269 |
+
/** Broadcast an array to a given shape. */
|
270 |
+
array broadcast_to(
|
271 |
+
const array& a,
|
272 |
+
const std::vector<int>& shape,
|
273 |
+
StreamOrDevice s = {});
|
274 |
+
|
275 |
+
/** Broadcast a vector of arrays against one another. */
|
276 |
+
std::vector<array> broadcast_arrays(
|
277 |
+
const std::vector<array>& inputs,
|
278 |
+
StreamOrDevice s = {});
|
279 |
+
|
280 |
+
/** Comparison operations */
|
281 |
+
|
282 |
+
/** Returns the bool array with (a == b) element-wise. */
|
283 |
+
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
284 |
+
inline array operator==(const array& a, const array& b) {
|
285 |
+
return equal(a, b);
|
286 |
+
}
|
287 |
+
template <typename T>
|
288 |
+
array operator==(T a, const array& b) {
|
289 |
+
return equal(array(a), b);
|
290 |
+
}
|
291 |
+
template <typename T>
|
292 |
+
array operator==(const array& a, T b) {
|
293 |
+
return equal(a, array(b));
|
294 |
+
}
|
295 |
+
|
296 |
+
/** Returns the bool array with (a != b) element-wise. */
|
297 |
+
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
|
298 |
+
inline array operator!=(const array& a, const array& b) {
|
299 |
+
return not_equal(a, b);
|
300 |
+
}
|
301 |
+
template <typename T>
|
302 |
+
array operator!=(T a, const array& b) {
|
303 |
+
return not_equal(array(a), b);
|
304 |
+
}
|
305 |
+
template <typename T>
|
306 |
+
array operator!=(const array& a, T b) {
|
307 |
+
return not_equal(a, array(b));
|
308 |
+
}
|
309 |
+
|
310 |
+
/** Returns bool array with (a > b) element-wise. */
|
311 |
+
array greater(const array& a, const array& b, StreamOrDevice s = {});
|
312 |
+
inline array operator>(const array& a, const array& b) {
|
313 |
+
return greater(a, b);
|
314 |
+
}
|
315 |
+
template <typename T>
|
316 |
+
array operator>(T a, const array& b) {
|
317 |
+
return greater(array(a), b);
|
318 |
+
}
|
319 |
+
template <typename T>
|
320 |
+
array operator>(const array& a, T b) {
|
321 |
+
return greater(a, array(b));
|
322 |
+
}
|
323 |
+
|
324 |
+
/** Returns bool array with (a >= b) element-wise. */
|
325 |
+
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
|
326 |
+
inline array operator>=(const array& a, const array& b) {
|
327 |
+
return greater_equal(a, b);
|
328 |
+
}
|
329 |
+
template <typename T>
|
330 |
+
array operator>=(T a, const array& b) {
|
331 |
+
return greater_equal(array(a), b);
|
332 |
+
}
|
333 |
+
template <typename T>
|
334 |
+
array operator>=(const array& a, T b) {
|
335 |
+
return greater_equal(a, array(b));
|
336 |
+
}
|
337 |
+
|
338 |
+
/** Returns bool array with (a < b) element-wise. */
|
339 |
+
array less(const array& a, const array& b, StreamOrDevice s = {});
|
340 |
+
inline array operator<(const array& a, const array& b) {
|
341 |
+
return less(a, b);
|
342 |
+
}
|
343 |
+
template <typename T>
|
344 |
+
array operator<(T a, const array& b) {
|
345 |
+
return less(array(a), b);
|
346 |
+
}
|
347 |
+
template <typename T>
|
348 |
+
array operator<(const array& a, T b) {
|
349 |
+
return less(a, array(b));
|
350 |
+
}
|
351 |
+
|
352 |
+
/** Returns bool array with (a <= b) element-wise. */
|
353 |
+
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
|
354 |
+
inline array operator<=(const array& a, const array& b) {
|
355 |
+
return less_equal(a, b);
|
356 |
+
}
|
357 |
+
template <typename T>
|
358 |
+
array operator<=(T a, const array& b) {
|
359 |
+
return less_equal(array(a), b);
|
360 |
+
}
|
361 |
+
template <typename T>
|
362 |
+
array operator<=(const array& a, T b) {
|
363 |
+
return less_equal(a, array(b));
|
364 |
+
}
|
365 |
+
|
366 |
+
/** True if two arrays have the same shape and elements. */
|
367 |
+
array array_equal(
|
368 |
+
const array& a,
|
369 |
+
const array& b,
|
370 |
+
bool equal_nan,
|
371 |
+
StreamOrDevice s = {});
|
372 |
+
inline array
|
373 |
+
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
|
374 |
+
return array_equal(a, b, false, s);
|
375 |
+
}
|
376 |
+
|
377 |
+
/** Select from x or y depending on condition. */
|
378 |
+
array where(
|
379 |
+
const array& condition,
|
380 |
+
const array& x,
|
381 |
+
const array& y,
|
382 |
+
StreamOrDevice s = {});
|
383 |
+
|
384 |
+
/** Reduction operations */
|
385 |
+
|
386 |
+
/** True if all elements in the array are true (or non-zero). **/
|
387 |
+
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
388 |
+
inline array all(const array& a, StreamOrDevice s = {}) {
|
389 |
+
return all(a, false, to_stream(s));
|
390 |
+
}
|
391 |
+
|
392 |
+
/** True if the two arrays are equal within the specified tolerance. */
|
393 |
+
array allclose(
|
394 |
+
const array& a,
|
395 |
+
const array& b,
|
396 |
+
double rtol = 1e-5,
|
397 |
+
double atol = 1e-8,
|
398 |
+
StreamOrDevice s = {});
|
399 |
+
|
400 |
+
/**
|
401 |
+
* Reduces the input along the given axes. An output value is true
|
402 |
+
* if all the corresponding inputs are true.
|
403 |
+
**/
|
404 |
+
array all(
|
405 |
+
const array& a,
|
406 |
+
const std::vector<int>& axes,
|
407 |
+
bool keepdims = false,
|
408 |
+
StreamOrDevice s = {});
|
409 |
+
|
410 |
+
/**
|
411 |
+
* Reduces the input along the given axis. An output value is true
|
412 |
+
* if all the corresponding inputs are true.
|
413 |
+
**/
|
414 |
+
array all(
|
415 |
+
const array& a,
|
416 |
+
int axis,
|
417 |
+
bool keepdims = false,
|
418 |
+
StreamOrDevice s = {});
|
419 |
+
|
420 |
+
/** True if any elements in the array are true (or non-zero). **/
|
421 |
+
array any(const array& a, bool keepdims, StreamOrDevice s = {});
|
422 |
+
inline array any(const array& a, StreamOrDevice s = {}) {
|
423 |
+
return any(a, false, to_stream(s));
|
424 |
+
}
|
425 |
+
|
426 |
+
/**
|
427 |
+
* Reduces the input along the given axes. An output value is true
|
428 |
+
* if any of the corresponding inputs are true.
|
429 |
+
**/
|
430 |
+
array any(
|
431 |
+
const array& a,
|
432 |
+
const std::vector<int>& axes,
|
433 |
+
bool keepdims = false,
|
434 |
+
StreamOrDevice s = {});
|
435 |
+
|
436 |
+
/**
|
437 |
+
* Reduces the input along the given axis. An output value is true
|
438 |
+
* if any of the corresponding inputs are true.
|
439 |
+
**/
|
440 |
+
array any(
|
441 |
+
const array& a,
|
442 |
+
int axis,
|
443 |
+
bool keepdims = false,
|
444 |
+
StreamOrDevice s = {});
|
445 |
+
|
446 |
+
/** Sums the elements of an array. */
|
447 |
+
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
|
448 |
+
inline array sum(const array& a, StreamOrDevice s = {}) {
|
449 |
+
return sum(a, false, to_stream(s));
|
450 |
+
}
|
451 |
+
|
452 |
+
/** Sums the elements of an array along the given axes. */
|
453 |
+
array sum(
|
454 |
+
const array& a,
|
455 |
+
const std::vector<int>& axes,
|
456 |
+
bool keepdims = false,
|
457 |
+
StreamOrDevice s = {});
|
458 |
+
|
459 |
+
/** Sums the elements of an array along the given axis. */
|
460 |
+
array sum(
|
461 |
+
const array& a,
|
462 |
+
int axis,
|
463 |
+
bool keepdims = false,
|
464 |
+
StreamOrDevice s = {});
|
465 |
+
|
466 |
+
/** Computes the mean of the elements of an array. */
|
467 |
+
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
|
468 |
+
inline array mean(const array& a, StreamOrDevice s = {}) {
|
469 |
+
return mean(a, false, to_stream(s));
|
470 |
+
}
|
471 |
+
|
472 |
+
/** Computes the mean of the elements of an array along the given axes */
|
473 |
+
array mean(
|
474 |
+
const array& a,
|
475 |
+
const std::vector<int>& axes,
|
476 |
+
bool keepdims = false,
|
477 |
+
StreamOrDevice s = {});
|
478 |
+
|
479 |
+
/** Computes the mean of the elements of an array along the given axis */
|
480 |
+
array mean(
|
481 |
+
const array& a,
|
482 |
+
int axis,
|
483 |
+
bool keepdims = false,
|
484 |
+
StreamOrDevice s = {});
|
485 |
+
|
486 |
+
/** Computes the mean of the elements of an array. */
|
487 |
+
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
488 |
+
inline array var(const array& a, StreamOrDevice s = {}) {
|
489 |
+
return var(a, false, 0, to_stream(s));
|
490 |
+
}
|
491 |
+
|
492 |
+
/** Computes the var of the elements of an array along the given axes */
|
493 |
+
array var(
|
494 |
+
const array& a,
|
495 |
+
const std::vector<int>& axes,
|
496 |
+
bool keepdims = false,
|
497 |
+
int ddof = 0,
|
498 |
+
StreamOrDevice s = {});
|
499 |
+
|
500 |
+
/** Computes the var of the elements of an array along the given axis */
|
501 |
+
array var(
|
502 |
+
const array& a,
|
503 |
+
int axis,
|
504 |
+
bool keepdims = false,
|
505 |
+
int ddof = 0,
|
506 |
+
StreamOrDevice s = {});
|
507 |
+
|
508 |
+
/** The product of all elements of the array. */
|
509 |
+
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
510 |
+
inline array prod(const array& a, StreamOrDevice s = {}) {
|
511 |
+
return prod(a, false, to_stream(s));
|
512 |
+
}
|
513 |
+
|
514 |
+
/** The product of the elements of an array along the given axes. */
|
515 |
+
array prod(
|
516 |
+
const array& a,
|
517 |
+
const std::vector<int>& axes,
|
518 |
+
bool keepdims = false,
|
519 |
+
StreamOrDevice s = {});
|
520 |
+
|
521 |
+
/** The product of the elements of an array along the given axis. */
|
522 |
+
array prod(
|
523 |
+
const array& a,
|
524 |
+
int axis,
|
525 |
+
bool keepdims = false,
|
526 |
+
StreamOrDevice s = {});
|
527 |
+
|
528 |
+
/** The maximum of all elements of the array. */
|
529 |
+
array max(const array& a, bool keepdims, StreamOrDevice s = {});
|
530 |
+
inline array max(const array& a, StreamOrDevice s = {}) {
|
531 |
+
return max(a, false, to_stream(s));
|
532 |
+
}
|
533 |
+
|
534 |
+
/** The maximum of the elements of an array along the given axes. */
|
535 |
+
array max(
|
536 |
+
const array& a,
|
537 |
+
const std::vector<int>& axes,
|
538 |
+
bool keepdims = false,
|
539 |
+
StreamOrDevice s = {});
|
540 |
+
|
541 |
+
/** The maximum of the elements of an array along the given axis. */
|
542 |
+
array max(
|
543 |
+
const array& a,
|
544 |
+
int axis,
|
545 |
+
bool keepdims = false,
|
546 |
+
StreamOrDevice s = {});
|
547 |
+
|
548 |
+
/** The minimum of all elements of the array. */
|
549 |
+
array min(const array& a, bool keepdims, StreamOrDevice s = {});
|
550 |
+
inline array min(const array& a, StreamOrDevice s = {}) {
|
551 |
+
return min(a, false, to_stream(s));
|
552 |
+
}
|
553 |
+
|
554 |
+
/** The minimum of the elements of an array along the given axes. */
|
555 |
+
array min(
|
556 |
+
const array& a,
|
557 |
+
const std::vector<int>& axes,
|
558 |
+
bool keepdims = false,
|
559 |
+
StreamOrDevice s = {});
|
560 |
+
|
561 |
+
/** The minimum of the elements of an array along the given axis. */
|
562 |
+
array min(
|
563 |
+
const array& a,
|
564 |
+
int axis,
|
565 |
+
bool keepdims = false,
|
566 |
+
StreamOrDevice s = {});
|
567 |
+
|
568 |
+
/** Returns the index of the minimum value in the array. */
|
569 |
+
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
|
570 |
+
inline array argmin(const array& a, StreamOrDevice s = {}) {
|
571 |
+
return argmin(a, false, s);
|
572 |
+
}
|
573 |
+
|
574 |
+
/** Returns the indices of the minimum values along a given axis. */
|
575 |
+
array argmin(
|
576 |
+
const array& a,
|
577 |
+
int axis,
|
578 |
+
bool keepdims = false,
|
579 |
+
StreamOrDevice s = {});
|
580 |
+
|
581 |
+
/** Returns the index of the maximum value in the array. */
|
582 |
+
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
|
583 |
+
inline array argmax(const array& a, StreamOrDevice s = {}) {
|
584 |
+
return argmax(a, false, s);
|
585 |
+
}
|
586 |
+
|
587 |
+
/** Returns the indices of the maximum values along a given axis. */
|
588 |
+
array argmax(
|
589 |
+
const array& a,
|
590 |
+
int axis,
|
591 |
+
bool keepdims = false,
|
592 |
+
StreamOrDevice s = {});
|
593 |
+
|
594 |
+
/** Returns a sorted copy of the flattened array. */
|
595 |
+
array sort(const array& a, StreamOrDevice s = {});
|
596 |
+
|
597 |
+
/** Returns a sorted copy of the array along a given axis. */
|
598 |
+
array sort(const array& a, int axis, StreamOrDevice s = {});
|
599 |
+
|
600 |
+
/** Returns indices that sort the flattened array. */
|
601 |
+
array argsort(const array& a, StreamOrDevice s = {});
|
602 |
+
|
603 |
+
/** Returns indices that sort the array along a given axis. */
|
604 |
+
array argsort(const array& a, int axis, StreamOrDevice s = {});
|
605 |
+
|
606 |
+
/**
|
607 |
+
* Returns a partitioned copy of the flattened array
|
608 |
+
* such that the smaller kth elements are first.
|
609 |
+
**/
|
610 |
+
array partition(const array& a, int kth, StreamOrDevice s = {});
|
611 |
+
|
612 |
+
/**
|
613 |
+
* Returns a partitioned copy of the array along a given axis
|
614 |
+
* such that the smaller kth elements are first.
|
615 |
+
**/
|
616 |
+
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
617 |
+
|
618 |
+
/**
|
619 |
+
* Returns indices that partition the flattened array
|
620 |
+
* such that the smaller kth elements are first.
|
621 |
+
**/
|
622 |
+
array argpartition(const array& a, int kth, StreamOrDevice s = {});
|
623 |
+
|
624 |
+
/**
|
625 |
+
* Returns indices that partition the array along a given axis
|
626 |
+
* such that the smaller kth elements are first.
|
627 |
+
**/
|
628 |
+
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
629 |
+
|
630 |
+
/** Returns topk elements of the flattened array. */
|
631 |
+
array topk(const array& a, int k, StreamOrDevice s = {});
|
632 |
+
|
633 |
+
/** Returns topk elements of the array along a given axis. */
|
634 |
+
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
635 |
+
|
636 |
+
/** The logsumexp of all elements of the array. */
|
637 |
+
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
|
638 |
+
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
|
639 |
+
return logsumexp(a, false, to_stream(s));
|
640 |
+
}
|
641 |
+
|
642 |
+
/** The logsumexp of the elements of an array along the given axes. */
|
643 |
+
array logsumexp(
|
644 |
+
const array& a,
|
645 |
+
const std::vector<int>& axes,
|
646 |
+
bool keepdims = false,
|
647 |
+
StreamOrDevice s = {});
|
648 |
+
|
649 |
+
/** The logsumexp of the elements of an array along the given axis. */
|
650 |
+
array logsumexp(
|
651 |
+
const array& a,
|
652 |
+
int axis,
|
653 |
+
bool keepdims = false,
|
654 |
+
StreamOrDevice s = {});
|
655 |
+
|
656 |
+
/** Simple arithmetic operations */
|
657 |
+
|
658 |
+
/** Absolute value of elements in an array. */
|
659 |
+
array abs(const array& a, StreamOrDevice s = {});
|
660 |
+
|
661 |
+
/** Negate an array. */
|
662 |
+
array negative(const array& a, StreamOrDevice s = {});
|
663 |
+
array operator-(const array& a);
|
664 |
+
|
665 |
+
/** The sign of the elements in an array. */
|
666 |
+
array sign(const array& a, StreamOrDevice s = {});
|
667 |
+
|
668 |
+
/** Logical not of an array */
|
669 |
+
array logical_not(const array& a, StreamOrDevice s = {});
|
670 |
+
|
671 |
+
/** The reciprocal (1/x) of the elements in an array. */
|
672 |
+
array reciprocal(const array& a, StreamOrDevice s = {});
|
673 |
+
|
674 |
+
/** Add two arrays. */
|
675 |
+
array add(const array& a, const array& b, StreamOrDevice s = {});
|
676 |
+
array operator+(const array& a, const array& b);
|
677 |
+
template <typename T>
|
678 |
+
array operator+(T a, const array& b) {
|
679 |
+
return add(array(a), b);
|
680 |
+
}
|
681 |
+
template <typename T>
|
682 |
+
array operator+(const array& a, T b) {
|
683 |
+
return add(a, array(b));
|
684 |
+
}
|
685 |
+
|
686 |
+
/** Subtract two arrays. */
|
687 |
+
array subtract(const array& a, const array& b, StreamOrDevice s = {});
|
688 |
+
array operator-(const array& a, const array& b);
|
689 |
+
template <typename T>
|
690 |
+
array operator-(T a, const array& b) {
|
691 |
+
return subtract(array(a), b);
|
692 |
+
}
|
693 |
+
template <typename T>
|
694 |
+
array operator-(const array& a, T b) {
|
695 |
+
return subtract(a, array(b));
|
696 |
+
}
|
697 |
+
|
698 |
+
/** Multiply two arrays. */
|
699 |
+
array multiply(const array& a, const array& b, StreamOrDevice s = {});
|
700 |
+
array operator*(const array& a, const array& b);
|
701 |
+
template <typename T>
|
702 |
+
array operator*(T a, const array& b) {
|
703 |
+
return multiply(array(a), b);
|
704 |
+
}
|
705 |
+
template <typename T>
|
706 |
+
array operator*(const array& a, T b) {
|
707 |
+
return multiply(a, array(b));
|
708 |
+
}
|
709 |
+
|
710 |
+
/** Divide two arrays. */
|
711 |
+
array divide(const array& a, const array& b, StreamOrDevice s = {});
|
712 |
+
array operator/(const array& a, const array& b);
|
713 |
+
array operator/(double a, const array& b);
|
714 |
+
array operator/(const array& a, double b);
|
715 |
+
|
716 |
+
/** Compute integer division. Equivalent to doing floor(a / x). */
|
717 |
+
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
|
718 |
+
|
719 |
+
/** Compute the element-wise remainder of division */
|
720 |
+
array remainder(const array& a, const array& b, StreamOrDevice s = {});
|
721 |
+
array operator%(const array& a, const array& b);
|
722 |
+
template <typename T>
|
723 |
+
array operator%(T a, const array& b) {
|
724 |
+
return remainder(array(a), b);
|
725 |
+
}
|
726 |
+
template <typename T>
|
727 |
+
array operator%(const array& a, T b) {
|
728 |
+
return remainder(a, array(b));
|
729 |
+
}
|
730 |
+
|
731 |
+
/** Element-wise maximum between two arrays. */
|
732 |
+
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
733 |
+
|
734 |
+
/** Element-wise minimum between two arrays. */
|
735 |
+
array minimum(const array& a, const array& b, StreamOrDevice s = {});
|
736 |
+
|
737 |
+
/** Floor the element of an array. **/
|
738 |
+
array floor(const array& a, StreamOrDevice s = {});
|
739 |
+
|
740 |
+
/** Ceil the element of an array. **/
|
741 |
+
array ceil(const array& a, StreamOrDevice s = {});
|
742 |
+
|
743 |
+
/** Square the elements of an array. */
|
744 |
+
array square(const array& a, StreamOrDevice s = {});
|
745 |
+
|
746 |
+
/** Exponential of the elements of an array. */
|
747 |
+
array exp(const array& a, StreamOrDevice s = {});
|
748 |
+
|
749 |
+
/** Sine of the elements of an array */
|
750 |
+
array sin(const array& a, StreamOrDevice s = {});
|
751 |
+
|
752 |
+
/** Cosine of the elements of an array */
|
753 |
+
array cos(const array& a, StreamOrDevice s = {});
|
754 |
+
|
755 |
+
/** Tangent of the elements of an array */
|
756 |
+
array tan(const array& a, StreamOrDevice s = {});
|
757 |
+
|
758 |
+
/** Arc Sine of the elements of an array */
|
759 |
+
array arcsin(const array& a, StreamOrDevice s = {});
|
760 |
+
|
761 |
+
/** Arc Cosine of the elements of an array */
|
762 |
+
array arccos(const array& a, StreamOrDevice s = {});
|
763 |
+
|
764 |
+
/** Arc Tangent of the elements of an array */
|
765 |
+
array arctan(const array& a, StreamOrDevice s = {});
|
766 |
+
|
767 |
+
/** Hyperbolic Sine of the elements of an array */
|
768 |
+
array sinh(const array& a, StreamOrDevice s = {});
|
769 |
+
|
770 |
+
/** Hyperbolic Cosine of the elements of an array */
|
771 |
+
array cosh(const array& a, StreamOrDevice s = {});
|
772 |
+
|
773 |
+
/** Hyperbolic Tangent of the elements of an array */
|
774 |
+
array tanh(const array& a, StreamOrDevice s = {});
|
775 |
+
|
776 |
+
/** Inverse Hyperbolic Sine of the elements of an array */
|
777 |
+
array arcsinh(const array& a, StreamOrDevice s = {});
|
778 |
+
|
779 |
+
/** Inverse Hyperbolic Cosine of the elements of an array */
|
780 |
+
array arccosh(const array& a, StreamOrDevice s = {});
|
781 |
+
|
782 |
+
/** Inverse Hyperbolic Tangent of the elements of an array */
|
783 |
+
array arctanh(const array& a, StreamOrDevice s = {});
|
784 |
+
|
785 |
+
/** Natural logarithm of the elements of an array. */
|
786 |
+
array log(const array& a, StreamOrDevice s = {});
|
787 |
+
|
788 |
+
/** Log base 2 of the elements of an array. */
|
789 |
+
array log2(const array& a, StreamOrDevice s = {});
|
790 |
+
|
791 |
+
/** Log base 10 of the elements of an array. */
|
792 |
+
array log10(const array& a, StreamOrDevice s = {});
|
793 |
+
|
794 |
+
/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
|
795 |
+
array log1p(const array& a, StreamOrDevice s = {});
|
796 |
+
|
797 |
+
/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */
|
798 |
+
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
|
799 |
+
|
800 |
+
/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */
|
801 |
+
array sigmoid(const array& a, StreamOrDevice s = {});
|
802 |
+
|
803 |
+
/** Computes the error function of the elements of an array. */
|
804 |
+
array erf(const array& a, StreamOrDevice s = {});
|
805 |
+
|
806 |
+
/** Computes the inverse error function of the elements of an array. */
|
807 |
+
array erfinv(const array& a, StreamOrDevice s = {});
|
808 |
+
|
809 |
+
/** Stop the flow of gradients. */
|
810 |
+
array stop_gradient(const array& a, StreamOrDevice s = {});
|
811 |
+
|
812 |
+
/** Round a floating point number */
|
813 |
+
array round(const array& a, int decimals, StreamOrDevice s = {});
|
814 |
+
inline array round(const array& a, StreamOrDevice s = {}) {
|
815 |
+
return round(a, 0, s);
|
816 |
+
}
|
817 |
+
|
818 |
+
/** Matrix-matrix multiplication. */
|
819 |
+
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
820 |
+
|
821 |
+
/** Gather array entries given indices and slices */
|
822 |
+
array gather(
|
823 |
+
const array& a,
|
824 |
+
const std::vector<array>& indices,
|
825 |
+
const std::vector<int>& axes,
|
826 |
+
const std::vector<int>& slice_sizes,
|
827 |
+
StreamOrDevice s = {});
|
828 |
+
inline array gather(
|
829 |
+
const array& a,
|
830 |
+
const array& indices,
|
831 |
+
int axis,
|
832 |
+
const std::vector<int>& slice_sizes,
|
833 |
+
StreamOrDevice s = {}) {
|
834 |
+
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
835 |
+
}
|
836 |
+
|
837 |
+
/** Take array slices at the given indices of the specified axis. */
|
838 |
+
array take(
|
839 |
+
const array& a,
|
840 |
+
const array& indices,
|
841 |
+
int axis,
|
842 |
+
StreamOrDevice s = {});
|
843 |
+
|
844 |
+
/** Take array entries at the given indices treating the array as flattened. */
|
845 |
+
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
846 |
+
|
847 |
+
/** Take array entries given indices along the axis */
|
848 |
+
array take_along_axis(
|
849 |
+
const array& a,
|
850 |
+
const array& indices,
|
851 |
+
int axis,
|
852 |
+
StreamOrDevice s = {});
|
853 |
+
|
854 |
+
/** Scatter updates to given linear indices */
|
855 |
+
array scatter(
|
856 |
+
const array& a,
|
857 |
+
const std::vector<array>& indices,
|
858 |
+
const array& updates,
|
859 |
+
const std::vector<int>& axes,
|
860 |
+
StreamOrDevice s = {});
|
861 |
+
inline array scatter(
|
862 |
+
const array& a,
|
863 |
+
const array& indices,
|
864 |
+
const array& updates,
|
865 |
+
int axis,
|
866 |
+
StreamOrDevice s = {}) {
|
867 |
+
return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
|
868 |
+
}
|
869 |
+
|
870 |
+
/** Scatter and add updates to given indices */
|
871 |
+
array scatter_add(
|
872 |
+
const array& a,
|
873 |
+
const std::vector<array>& indices,
|
874 |
+
const array& updates,
|
875 |
+
const std::vector<int>& axes,
|
876 |
+
StreamOrDevice s = {});
|
877 |
+
inline array scatter_add(
|
878 |
+
const array& a,
|
879 |
+
const array& indices,
|
880 |
+
const array& updates,
|
881 |
+
int axis,
|
882 |
+
StreamOrDevice s = {}) {
|
883 |
+
return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
|
884 |
+
}
|
885 |
+
|
886 |
+
/** Scatter and prod updates to given indices */
|
887 |
+
array scatter_prod(
|
888 |
+
const array& a,
|
889 |
+
const std::vector<array>& indices,
|
890 |
+
const array& updates,
|
891 |
+
const std::vector<int>& axes,
|
892 |
+
StreamOrDevice s = {});
|
893 |
+
inline array scatter_prod(
|
894 |
+
const array& a,
|
895 |
+
const array& indices,
|
896 |
+
const array& updates,
|
897 |
+
int axis,
|
898 |
+
StreamOrDevice s = {}) {
|
899 |
+
return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
|
900 |
+
}
|
901 |
+
|
902 |
+
/** Scatter and max updates to given linear indices */
|
903 |
+
array scatter_max(
|
904 |
+
const array& a,
|
905 |
+
const std::vector<array>& indices,
|
906 |
+
const array& updates,
|
907 |
+
const std::vector<int>& axes,
|
908 |
+
StreamOrDevice s = {});
|
909 |
+
inline array scatter_max(
|
910 |
+
const array& a,
|
911 |
+
const array& indices,
|
912 |
+
const array& updates,
|
913 |
+
int axis,
|
914 |
+
StreamOrDevice s = {}) {
|
915 |
+
return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
|
916 |
+
}
|
917 |
+
/** Scatter and min updates to given linear indices */
|
918 |
+
array scatter_min(
|
919 |
+
const array& a,
|
920 |
+
const std::vector<array>& indices,
|
921 |
+
const array& updates,
|
922 |
+
const std::vector<int>& axes,
|
923 |
+
StreamOrDevice s = {});
|
924 |
+
inline array scatter_min(
|
925 |
+
const array& a,
|
926 |
+
const array& indices,
|
927 |
+
const array& updates,
|
928 |
+
int axis,
|
929 |
+
StreamOrDevice s = {}) {
|
930 |
+
return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
|
931 |
+
}
|
932 |
+
|
933 |
+
/** Square root the elements of an array. */
|
934 |
+
array sqrt(const array& a, StreamOrDevice s = {});
|
935 |
+
|
936 |
+
/** Square root and reciprocal the elements of an array. */
|
937 |
+
array rsqrt(const array& a, StreamOrDevice s = {});
|
938 |
+
|
939 |
+
/** Softmax of an array. */
|
940 |
+
array softmax(
|
941 |
+
const array& a,
|
942 |
+
const std::vector<int>& axes,
|
943 |
+
StreamOrDevice s = {});
|
944 |
+
|
945 |
+
/** Softmax of an array. */
|
946 |
+
array softmax(const array& a, StreamOrDevice s = {});
|
947 |
+
|
948 |
+
/** Softmax of an array. */
|
949 |
+
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
|
950 |
+
return softmax(a, std::vector<int>{axis}, s);
|
951 |
+
}
|
952 |
+
|
953 |
+
/** Raise elements of a to the power of b element-wise */
|
954 |
+
array power(const array& a, const array& b, StreamOrDevice s = {});
|
955 |
+
inline array operator^(const array& a, const array& b) {
|
956 |
+
return power(a, b);
|
957 |
+
}
|
958 |
+
template <typename T>
|
959 |
+
array operator^(T a, const array& b) {
|
960 |
+
return power(array(a), b);
|
961 |
+
}
|
962 |
+
template <typename T>
|
963 |
+
array operator^(const array& a, T b) {
|
964 |
+
return power(a, array(b));
|
965 |
+
}
|
966 |
+
|
967 |
+
/** Cumulative sum of an array. */
|
968 |
+
array cumsum(
|
969 |
+
const array& a,
|
970 |
+
int axis,
|
971 |
+
bool reverse = false,
|
972 |
+
bool inclusive = true,
|
973 |
+
StreamOrDevice s = {});
|
974 |
+
|
975 |
+
/** Cumulative product of an array. */
|
976 |
+
array cumprod(
|
977 |
+
const array& a,
|
978 |
+
int axis,
|
979 |
+
bool reverse = false,
|
980 |
+
bool inclusive = true,
|
981 |
+
StreamOrDevice s = {});
|
982 |
+
|
983 |
+
/** Cumulative max of an array. */
|
984 |
+
array cummax(
|
985 |
+
const array& a,
|
986 |
+
int axis,
|
987 |
+
bool reverse = false,
|
988 |
+
bool inclusive = true,
|
989 |
+
StreamOrDevice s = {});
|
990 |
+
|
991 |
+
/** Cumulative min of an array. */
|
992 |
+
array cummin(
|
993 |
+
const array& a,
|
994 |
+
int axis,
|
995 |
+
bool reverse = false,
|
996 |
+
bool inclusive = true,
|
997 |
+
StreamOrDevice s = {});
|
998 |
+
|
999 |
+
/** Convolution operations */
|
1000 |
+
|
1001 |
+
/** 1D convolution with a filter */
|
1002 |
+
array conv1d(
|
1003 |
+
const array& input,
|
1004 |
+
const array& weight,
|
1005 |
+
int stride = 1,
|
1006 |
+
int padding = 0,
|
1007 |
+
int dilation = 1,
|
1008 |
+
int groups = 1,
|
1009 |
+
StreamOrDevice s = {});
|
1010 |
+
|
1011 |
+
/** 2D convolution with a filter */
|
1012 |
+
array conv2d(
|
1013 |
+
const array& input,
|
1014 |
+
const array& weight,
|
1015 |
+
const std::pair<int, int>& stride = {1, 1},
|
1016 |
+
const std::pair<int, int>& padding = {0, 0},
|
1017 |
+
const std::pair<int, int>& dilation = {1, 1},
|
1018 |
+
int groups = 1,
|
1019 |
+
StreamOrDevice s = {});
|
1020 |
+
|
1021 |
+
/** Serialization operations */
|
1022 |
+
|
1023 |
+
/** Save array to out stream in .npy format */
|
1024 |
+
void save(
|
1025 |
+
std::shared_ptr<io::Writer> out_stream,
|
1026 |
+
array a,
|
1027 |
+
bool retain_graph = true);
|
1028 |
+
|
1029 |
+
/** Save array to file in .npy format */
|
1030 |
+
void save(const std::string& file, array a, bool retain_graph = true);
|
1031 |
+
|
1032 |
+
/** Load array from reader in .npy format */
|
1033 |
+
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
|
1034 |
+
|
1035 |
+
/** Load array from file in .npy format */
|
1036 |
+
array load(const std::string& file, StreamOrDevice s = {});
|
1037 |
+
|
1038 |
+
/** Quantized matmul multiplies x with a quantized matrix w*/
|
1039 |
+
array quantized_matmul(
|
1040 |
+
const array& x,
|
1041 |
+
const array& w,
|
1042 |
+
const array& scales,
|
1043 |
+
const array& biases,
|
1044 |
+
bool transpose = true,
|
1045 |
+
int group_size = 64,
|
1046 |
+
int bits = 4,
|
1047 |
+
StreamOrDevice s = {});
|
1048 |
+
|
1049 |
+
/** Quantize a matrix along its last axis */
|
1050 |
+
std::tuple<array, array, array> quantize(
|
1051 |
+
const array& w,
|
1052 |
+
int group_size = 64,
|
1053 |
+
int bits = 4,
|
1054 |
+
StreamOrDevice s = {});
|
1055 |
+
|
1056 |
+
/** Dequantize a matrix produced by quantize() */
|
1057 |
+
array dequantize(
|
1058 |
+
const array& w,
|
1059 |
+
const array& scales,
|
1060 |
+
const array& biases,
|
1061 |
+
int group_size = 64,
|
1062 |
+
int bits = 4,
|
1063 |
+
StreamOrDevice s = {});
|
1064 |
+
|
1065 |
+
/** TensorDot returns a contraction of a and b over multiple dimensions. */
|
1066 |
+
array tensordot(
|
1067 |
+
const array& a,
|
1068 |
+
const array& b,
|
1069 |
+
const int dims = 2,
|
1070 |
+
StreamOrDevice s = {});
|
1071 |
+
|
1072 |
+
array tensordot(
|
1073 |
+
const array& a,
|
1074 |
+
const array& b,
|
1075 |
+
const std::pair<std::vector<int>, std::vector<int>>& dims,
|
1076 |
+
StreamOrDevice s = {});
|
1077 |
+
|
1078 |
+
/** Load array map from .safetensors file format */
|
1079 |
+
std::unordered_map<std::string, array> load_safetensors(
|
1080 |
+
std::shared_ptr<io::Reader> in_stream,
|
1081 |
+
StreamOrDevice s = {});
|
1082 |
+
std::unordered_map<std::string, array> load_safetensors(
|
1083 |
+
const std::string& file,
|
1084 |
+
StreamOrDevice s = {});
|
1085 |
+
|
1086 |
+
void save_safetensors(
|
1087 |
+
std::shared_ptr<io::Writer> in_stream,
|
1088 |
+
std::unordered_map<std::string, array>,
|
1089 |
+
std::optional<bool> retain_graph = std::nullopt);
|
1090 |
+
void save_safetensors(
|
1091 |
+
const std::string& file,
|
1092 |
+
std::unordered_map<std::string, array>,
|
1093 |
+
std::optional<bool> retain_graph = std::nullopt);
|
1094 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/primitives.h
ADDED
@@ -0,0 +1,1636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "array.h"
|
6 |
+
#include "device.h"
|
7 |
+
#include "io/load.h"
|
8 |
+
#include "stream.h"
|
9 |
+
|
10 |
+
// Declares the jvp/vjp gradient overrides for a Primitive subclass.
// The definitions are provided out of line by each primitive.
#define DEFINE_GRADS()                           \
  array jvp(                                     \
      const std::vector<array>& primals,         \
      const std::vector<array>& tangents,        \
      const std::vector<int>& argnums) override; \
                                                 \
  std::vector<array> vjp(                        \
      const std::vector<array>& primals,         \
      const array& cotan,                        \
      const std::vector<int>& argnums) override;
|
20 |
+
|
21 |
+
// Declares a print() override that writes the primitive's class name
// (via the preprocessor stringizing operator) to the stream.
#define DEFINE_PRINT(PRIMITIVE)           \
  void print(std::ostream& os) override { \
    os << #PRIMITIVE;                     \
  }
|
25 |
+
|
26 |
+
// Declares an is_equivalent() override that always returns true. Intended
// for primitives with no parameters, where any two instances of the same
// type are interchangeable.
#define DEFINE_DEFAULT_IS_EQUIVALENT()                        \
  bool is_equivalent(const Primitive& other) const override { \
    return true;                                              \
  }
|
30 |
+
|
31 |
+
namespace mlx::core {
|
32 |
+
|
33 |
+
// Abstract base class for all computation-graph primitives. A primitive
// knows how to evaluate itself on CPU/GPU, differentiate itself (jvp/vjp),
// vectorize itself (vmap), print itself, and compare itself to others.
class Primitive {
 public:
  explicit Primitive(Stream stream) : stream_(stream) {}

  /** The device the primitive will run on. */
  const Device& device() {
    return stream().device;
  }

  /** The stream the primitive will run on. */
  const Stream& stream() {
    return stream_;
  }

  /**
   * A primitive must know how to evaluate itself on
   * the CPU/GPU for the given inputs and populate the output array.
   *
   * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
  virtual void eval_gpu(const std::vector<array>& inputs, array& out) = 0;

  /**
   * The Jacobian-vector product.
   */
  virtual array jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums);

  /**
   * The vector-Jacobian product.
   */
  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const array& cotan,
      const std::vector<int>& argnums);

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  virtual std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes);

  /** Print the primitive. */
  virtual void print(std::ostream& os) = 0;

  /** Equivalence check defaults to false unless overridden by the primitive */
  virtual bool is_equivalent(const Primitive& other) const {
    return false;
  }

  virtual ~Primitive() = default;
  // Primitives are neither copyable nor movable.
  Primitive(const Primitive& other) = delete;
  Primitive(Primitive&& other) = delete;
  Primitive& operator=(const Primitive& other) = delete;
  Primitive& operator=(Primitive&& other) = delete;

 private:
  // Every primitive stores the stream it should run in
  Stream stream_;
};
|
102 |
+
|
103 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Abs : public Primitive {
 public:
  explicit Abs(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Abs)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
121 |
+
|
122 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Add : public Primitive {
 public:
  explicit Add(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Add)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
140 |
+
|
141 |
+
// Primitive parameterized by (start, stop, step). No vmap or gradients are
// declared; equivalence is a custom out-of-line override (presumably
// comparing the parameters — defined elsewhere).
class Arange : public Primitive {
 public:
  explicit Arange(Stream stream, double start, double stop, double step)
      : Primitive(stream), start_(start), stop_(stop), step_(step){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_PRINT(Arange)
  bool is_equivalent(const Primitive& other) const override;

 private:
  double start_;
  double stop_;
  double step_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
159 |
+
|
160 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcCos : public Primitive {
 public:
  explicit ArcCos(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcCos)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
178 |
+
|
179 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcCosh : public Primitive {
 public:
  explicit ArcCosh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcCosh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
197 |
+
|
198 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcSin : public Primitive {
 public:
  explicit ArcSin(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcSin)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
216 |
+
|
217 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcSinh : public Primitive {
 public:
  explicit ArcSinh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcSinh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
235 |
+
|
236 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcTan : public Primitive {
 public:
  explicit ArcTan(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcTan)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
254 |
+
|
255 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ArcTanh : public Primitive {
 public:
  explicit ArcTanh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ArcTanh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
273 |
+
|
274 |
+
// Primitive parameterized by (kth, axis). Declares vmap but no gradients;
// equivalence is a custom out-of-line override.
class ArgPartition : public Primitive {
 public:
  explicit ArgPartition(Stream stream, int kth, int axis)
      : Primitive(stream), kth_(kth), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(ArgPartition)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int kth_;
  int axis_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
295 |
+
|
296 |
+
// Primitive parameterized by a reduce type (ArgMin/ArgMax) and an axis.
// Declares no vmap or gradients; equivalence is a custom out-of-line override.
class ArgReduce : public Primitive {
 public:
  // Which index-reduction to perform.
  enum ReduceType {
    ArgMin,
    ArgMax,
  };

  explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
      : Primitive(stream), reduce_type_(reduce_type), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_PRINT(ArgReduce)
  bool is_equivalent(const Primitive& other) const override;

 private:
  ReduceType reduce_type_;
  int axis_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
318 |
+
|
319 |
+
// Primitive parameterized by an axis. Declares vmap but no gradients;
// equivalence is a custom out-of-line override.
class ArgSort : public Primitive {
 public:
  explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(ArgSort)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int axis_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
338 |
+
|
339 |
+
// Primitive parameterized by a target Dtype. Declares vmap and gradients;
// equivalence is a custom out-of-line override.
class AsType : public Primitive {
 public:
  explicit AsType(Stream stream, Dtype dtype)
      : Primitive(stream), dtype_(dtype){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(AsType)
  bool is_equivalent(const Primitive& other) const override;

 private:
  Dtype dtype_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
360 |
+
|
361 |
+
// Primitive parameterized by (shape, strides, offset). Declares gradients
// but no vmap; equivalence is a custom out-of-line override.
class AsStrided : public Primitive {
 public:
  explicit AsStrided(
      Stream stream,
      const std::vector<int>& shape,
      const std::vector<size_t>& strides,
      size_t offset)
      : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_GRADS()
  DEFINE_PRINT(AsStrided)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> shape_;
  std::vector<size_t> strides_;
  size_t offset_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
384 |
+
|
385 |
+
// Primitive parameterized by a target shape. Declares vmap and gradients;
// equivalence is a custom out-of-line override.
class Broadcast : public Primitive {
 public:
  explicit Broadcast(Stream stream, const std::vector<int>& shape)
      : Primitive(stream), shape_(shape){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Broadcast)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> shape_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
406 |
+
|
407 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Ceil : public Primitive {
 public:
  explicit Ceil(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Ceil)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
425 |
+
|
426 |
+
// Primitive parameterized by the concatenation axis. Declares vmap and
// gradients; equivalence is a custom out-of-line override.
class Concatenate : public Primitive {
 public:
  explicit Concatenate(Stream stream, int axis)
      : Primitive(stream), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Concatenate)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int axis_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
447 |
+
|
448 |
+
// Primitive parameterized by padding, kernel strides, kernel dilation, and
// input dilation. Declares only a custom vjp (no jvp or vmap); equivalence
// is a custom out-of-line override.
class Convolution : public Primitive {
 public:
  explicit Convolution(
      Stream stream,
      const std::vector<int>& padding,
      const std::vector<int>& kernel_strides,
      const std::vector<int>& kernel_dilation,
      const std::vector<int>& input_dilation)
      : Primitive(stream),
        padding_(padding),
        kernel_strides_(kernel_strides),
        kernel_dilation_(kernel_dilation),
        input_dilation_(input_dilation){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const array& cotan,
      const std::vector<int>& argnums) override;

  DEFINE_PRINT(Convolution)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> padding_;
  std::vector<int> kernel_strides_;
  std::vector<int> kernel_dilation_;
  std::vector<int> input_dilation_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
481 |
+
|
482 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Copy : public Primitive {
 public:
  explicit Copy(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Copy)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
500 |
+
|
501 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Cos : public Primitive {
 public:
  explicit Cos(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Cos)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
519 |
+
|
520 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Cosh : public Primitive {
 public:
  explicit Cosh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Cosh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
538 |
+
|
539 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Divide : public Primitive {
 public:
  explicit Divide(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Divide)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
557 |
+
|
558 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Remainder : public Primitive {
 public:
  explicit Remainder(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Remainder)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
576 |
+
|
577 |
+
// Primitive parameterized by an equal_nan flag. Declares vmap and gradients.
// NOTE(review): DEFINE_DEFAULT_IS_EQUIVALENT makes is_equivalent() ignore
// equal_nan_, so two Equal primitives with different flags compare as
// equivalent — confirm this is intended.
class Equal : public Primitive {
 public:
  explicit Equal(Stream stream, bool equal_nan = false)
      : Primitive(stream), equal_nan_(equal_nan){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Equal)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
  bool equal_nan_;
};
|
597 |
+
|
598 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Erf : public Primitive {
 public:
  explicit Erf(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Erf)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
616 |
+
|
617 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class ErfInv : public Primitive {
 public:
  explicit ErfInv(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(ErfInv)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
635 |
+
|
636 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Exp : public Primitive {
 public:
  explicit Exp(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Exp)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
654 |
+
|
655 |
+
// Primitive parameterized by the transform axes, an inverse flag, and a
// real-input flag. Declares vmap and gradients; equivalence is a custom
// out-of-line override.
class FFT : public Primitive {
 public:
  explicit FFT(
      Stream stream,
      const std::vector<size_t>& axes,
      bool inverse,
      bool real)
      : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(FFT)

  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<size_t> axes_;
  bool inverse_;
  bool real_;

  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
683 |
+
|
684 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Floor : public Primitive {
 public:
  explicit Floor(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Floor)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
702 |
+
|
703 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Full : public Primitive {
 public:
  explicit Full(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Full)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
721 |
+
|
722 |
+
// Primitive parameterized by gather axes and per-axis slice sizes. Declares
// vmap and gradients; equivalence is a custom out-of-line override.
class Gather : public Primitive {
 public:
  explicit Gather(
      Stream stream,
      const std::vector<int>& axes,
      const std::vector<int>& slice_sizes)
      : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Gather)
  bool is_equivalent(const Primitive& other) const override;

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
  std::vector<int> axes_;
  std::vector<int> slice_sizes_;
};
|
746 |
+
|
747 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Greater : public Primitive {
 public:
  explicit Greater(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Greater)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
765 |
+
|
766 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class GreaterEqual : public Primitive {
 public:
  explicit GreaterEqual(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(GreaterEqual)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
784 |
+
|
785 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class Less : public Primitive {
 public:
  explicit Less(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Less)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
803 |
+
|
804 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class LessEqual : public Primitive {
 public:
  explicit LessEqual(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(LessEqual)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
822 |
+
|
823 |
+
// Primitive that reads array data through an io::Reader starting at a byte
// offset, optionally swapping endianness. No vmap, gradients, or custom
// equivalence are declared (is_equivalent defaults to false, i.e. stateful
// loads are never merged).
class Load : public Primitive {
 public:
  explicit Load(
      Stream stream,
      std::shared_ptr<io::Reader> reader,
      size_t offset,
      bool swap_endianness = false)
      : Primitive(stream),
        reader_(reader),
        offset_(offset),
        swap_endianness_(swap_endianness){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_PRINT(Load)

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
  std::shared_ptr<io::Reader> reader_;
  size_t offset_;
  bool swap_endianness_;
};
|
846 |
+
|
847 |
+
// Primitive parameterized by the logarithm base (2, 10, or e). Declares
// vmap and gradients.
// NOTE(review): DEFINE_DEFAULT_IS_EQUIVALENT makes is_equivalent() ignore
// base_, so e.g. a base-2 and a base-10 Log compare as equivalent — confirm
// this is intended.
class Log : public Primitive {
 public:
  // Supported logarithm bases.
  enum Base { two, ten, e };

  explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Log)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  Base base_;
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
868 |
+
|
869 |
+
class Log1p : public Primitive {
|
870 |
+
public:
|
871 |
+
explicit Log1p(Stream stream) : Primitive(stream){};
|
872 |
+
|
873 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
874 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
875 |
+
|
876 |
+
std::pair<array, int> vmap(
|
877 |
+
const std::vector<array>& inputs,
|
878 |
+
const std::vector<int>& axes) override;
|
879 |
+
|
880 |
+
DEFINE_GRADS()
|
881 |
+
DEFINE_PRINT(Log1p)
|
882 |
+
|
883 |
+
private:
|
884 |
+
void eval(const std::vector<array>& inputs, array& out);
|
885 |
+
};
|
886 |
+
|
887 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class LogicalNot : public Primitive {
 public:
  explicit LogicalNot(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(LogicalNot)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
905 |
+
|
906 |
+
// Parameter-free primitive: declares CPU/GPU evaluation, vmap, and jvp/vjp
// gradients; any two instances are equivalent.
class LogAddExp : public Primitive {
 public:
  explicit LogAddExp(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(LogAddExp)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  // Device-independent evaluation shared by eval_cpu/eval_gpu.
  void eval(const std::vector<array>& inputs, array& out);
};
|
924 |
+
|
925 |
+
// Matrix multiplication primitive. Declares an explicit vjp (rather than
// the DEFINE_GRADS() macro) and no private eval() helper, unlike the
// element-wise primitives in this file.
class Matmul : public Primitive {
 public:
  explicit Matmul(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  // Vector-Jacobian product for the requested argument indices.
  std::vector<array> vjp(
      const std::vector<array>& primals,
      const array& cotan,
      const std::vector<int>& argnums) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(Matmul)
  DEFINE_DEFAULT_IS_EQUIVALENT()
};
|
944 |
+
|
945 |
+
// Stateless element-wise primitives: CPU/GPU eval, vmap, macro-generated
// grads/printing/equivalence, shared private eval().

// Element-wise maximum of two inputs.
class Maximum : public Primitive {
 public:
  explicit Maximum(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Maximum)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise minimum of two inputs.
class Minimum : public Primitive {
 public:
  explicit Minimum(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Minimum)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise product of two inputs.
class Multiply : public Primitive {
 public:
  explicit Multiply(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Multiply)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise negation.
class Negative : public Primitive {
 public:
  explicit Negative(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Negative)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise inequality comparison.
class NotEqual : public Primitive {
 public:
  explicit NotEqual(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(NotEqual)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1039 |
+
|
1040 |
+
// Pads the input along the given axes with the given low/high pad sizes.
// The three vectors are presumably parallel (one entry per padded axis) —
// confirm against the op implementation.
class Pad : public Primitive {
 public:
  explicit Pad(
      Stream stream,
      const std::vector<int>& axes,
      const std::vector<int>& low_pad_size,
      const std::vector<int>& high_pad_size)
      : Primitive(stream),
        axes_(axes),
        low_pad_size_(low_pad_size),
        high_pad_size_(high_pad_size){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Pad)
  // Custom equivalence: must compare the padding parameters, so the
  // DEFINE_DEFAULT_IS_EQUIVALENT() macro is not used here.
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> axes_;           // axes to pad
  std::vector<int> low_pad_size_;   // padding before each axis
  std::vector<int> high_pad_size_;  // padding after each axis

  void eval(const std::vector<array>& inputs, array& out);
};

// Partial sort: places the kth element along `axis` in its sorted position
// (name-based; confirm against the op implementation).
class Partition : public Primitive {
 public:
  explicit Partition(Stream stream, int kth, int axis)
      : Primitive(stream), kth_(kth), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Partition)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int kth_;
  int axis_;

  void eval(const std::vector<array>& inputs, array& out);
};
|
1093 |
+
|
1094 |
+
// Element-wise exponentiation (base, exponent inputs).
class Power : public Primitive {
 public:
  explicit Power(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Power)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1112 |
+
|
1113 |
+
// Matrix multiplication against a quantized weight matrix.
class QuantizedMatmul : public Primitive {
 public:
  explicit QuantizedMatmul(
      Stream stream,
      int group_size,
      int bits,
      bool transpose)
      : Primitive(stream),
        group_size_(group_size),
        bits_(bits),
        transpose_(transpose){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(QuantizedMatmul)
  // Custom equivalence: must compare the quantization parameters.
  bool is_equivalent(const Primitive& other) const override;

 private:
  int group_size_;  // elements sharing one scale/bias group — TODO confirm
  int bits_;        // bits per quantized element — TODO confirm
  bool transpose_;

  void eval(const std::vector<array>& inputs, array& out);
};

// Produces an array of random bits with the given shape.
// No gradient definitions: random generation is not differentiable.
class RandomBits : public Primitive {
 public:
  explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
      : Primitive(stream), shape_(shape), width_(width){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(RandomBits)
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> shape_;  // output shape
  int width_;               // presumably bytes per element — TODO confirm

  void eval(const std::vector<array>& inputs, array& out);
};
|
1165 |
+
|
1166 |
+
// Reshapes the input to the stored target shape.
class Reshape : public Primitive {
 public:
  explicit Reshape(Stream stream, const std::vector<int>& shape)
      : Primitive(stream), shape_(shape){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Reshape)
  // Custom equivalence: must compare the target shapes.
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> shape_;  // target output shape

  void eval(const std::vector<array>& inputs, array& out);
};
|
1187 |
+
|
1188 |
+
class Reduce : public Primitive {
|
1189 |
+
public:
|
1190 |
+
enum ReduceType { And, Or, Sum, Prod, Min, Max };
|
1191 |
+
|
1192 |
+
explicit Reduce(
|
1193 |
+
Stream stream,
|
1194 |
+
ReduceType reduce_type,
|
1195 |
+
const std::vector<int>& axes)
|
1196 |
+
: Primitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
1197 |
+
|
1198 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
1199 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
1200 |
+
|
1201 |
+
std::pair<array, int> vmap(
|
1202 |
+
const std::vector<array>& inputs,
|
1203 |
+
const std::vector<int>& axes) override;
|
1204 |
+
std::vector<array> vjp(
|
1205 |
+
const std::vector<array>& primals,
|
1206 |
+
const array& cotan,
|
1207 |
+
const std::vector<int>& argnums) override;
|
1208 |
+
|
1209 |
+
void print(std::ostream& os) override {
|
1210 |
+
switch (reduce_type_) {
|
1211 |
+
case And:
|
1212 |
+
os << "And";
|
1213 |
+
case Or:
|
1214 |
+
os << "And";
|
1215 |
+
break;
|
1216 |
+
case Sum:
|
1217 |
+
os << "Sum";
|
1218 |
+
break;
|
1219 |
+
case Prod:
|
1220 |
+
os << "Prod";
|
1221 |
+
break;
|
1222 |
+
case Min:
|
1223 |
+
os << "Min";
|
1224 |
+
break;
|
1225 |
+
case Max:
|
1226 |
+
os << "Max";
|
1227 |
+
break;
|
1228 |
+
}
|
1229 |
+
os << " Reduce";
|
1230 |
+
}
|
1231 |
+
bool is_equivalent(const Primitive& other) const override;
|
1232 |
+
|
1233 |
+
private:
|
1234 |
+
ReduceType reduce_type_;
|
1235 |
+
std::vector<int> axes_;
|
1236 |
+
|
1237 |
+
void eval(const std::vector<array>& inputs, array& out);
|
1238 |
+
};
|
1239 |
+
|
1240 |
+
// Element-wise rounding to the nearest integer value.
class Round : public Primitive {
 public:
  explicit Round(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Round)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1258 |
+
|
1259 |
+
class Scan : public Primitive {
|
1260 |
+
public:
|
1261 |
+
enum ReduceType { Max, Min, Sum, Prod };
|
1262 |
+
|
1263 |
+
explicit Scan(
|
1264 |
+
Stream stream,
|
1265 |
+
ReduceType reduce_type,
|
1266 |
+
int axis,
|
1267 |
+
bool reverse,
|
1268 |
+
bool inclusive)
|
1269 |
+
: Primitive(stream),
|
1270 |
+
reduce_type_(reduce_type),
|
1271 |
+
axis_(axis),
|
1272 |
+
reverse_(reverse),
|
1273 |
+
inclusive_(inclusive){};
|
1274 |
+
|
1275 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
1276 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
1277 |
+
|
1278 |
+
std::pair<array, int> vmap(
|
1279 |
+
const std::vector<array>& inputs,
|
1280 |
+
const std::vector<int>& axes) override;
|
1281 |
+
|
1282 |
+
DEFINE_GRADS();
|
1283 |
+
void print(std::ostream& os) override {
|
1284 |
+
os << "Cum";
|
1285 |
+
switch (reduce_type_) {
|
1286 |
+
case Sum:
|
1287 |
+
os << "Sum";
|
1288 |
+
break;
|
1289 |
+
case Prod:
|
1290 |
+
os << "Prod";
|
1291 |
+
break;
|
1292 |
+
case Min:
|
1293 |
+
os << "Min";
|
1294 |
+
break;
|
1295 |
+
case Max:
|
1296 |
+
os << "Max";
|
1297 |
+
break;
|
1298 |
+
}
|
1299 |
+
os << " Reduce";
|
1300 |
+
}
|
1301 |
+
bool is_equivalent(const Primitive& other) const override;
|
1302 |
+
|
1303 |
+
private:
|
1304 |
+
ReduceType reduce_type_;
|
1305 |
+
int axis_;
|
1306 |
+
bool reverse_;
|
1307 |
+
bool inclusive_;
|
1308 |
+
|
1309 |
+
void eval(const std::vector<array>& inputs, array& out);
|
1310 |
+
};
|
1311 |
+
|
1312 |
+
// Scatters update values into an output at indexed locations, combining
// with the given reduction (None = plain overwrite — TODO confirm).
// NOTE(review): declares neither vmap nor gradient macros, unlike most
// primitives in this file — confirm this is intentional.
class Scatter : public Primitive {
 public:
  // How colliding updates are combined.
  enum ReduceType { Max, Min, Sum, Prod, None };

  explicit Scatter(
      Stream stream,
      ReduceType reduce_type,
      const std::vector<int>& axes)
      : Primitive(stream), reduce_type_(reduce_type), axes_(axes){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_PRINT(Scatter)
  bool is_equivalent(const Primitive& other) const override;

 private:
  void eval(const std::vector<array>& inputs, array& out);
  ReduceType reduce_type_;  // combination rule for scattered values
  std::vector<int> axes_;   // axes the indices address
};
|
1333 |
+
|
1334 |
+
// Stateless element-wise unary primitives: CPU/GPU eval, vmap,
// macro-generated grads/printing/equivalence, shared private eval().

// Logistic sigmoid.
class Sigmoid : public Primitive {
 public:
  explicit Sigmoid(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sigmoid)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Sign of each element.
class Sign : public Primitive {
 public:
  explicit Sign(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sign)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Sine.
class Sin : public Primitive {
 public:
  explicit Sin(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sin)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Hyperbolic sine.
class Sinh : public Primitive {
 public:
  explicit Sinh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sinh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1409 |
+
|
1410 |
+
// Strided slice of the input; the three vectors are presumably parallel
// (one entry per sliced axis) — confirm against the op implementation.
class Slice : public Primitive {
 public:
  explicit Slice(
      Stream stream,
      const std::vector<int>& start_indices,
      const std::vector<int>& end_indices,
      const std::vector<int>& strides)
      : Primitive(stream),
        start_indices_(start_indices),
        end_indices_(end_indices),
        strides_(strides){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Slice)
  // Custom equivalence: must compare the slice parameters.
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> start_indices_;  // inclusive start per axis
  std::vector<int> end_indices_;    // exclusive end per axis — TODO confirm
  std::vector<int> strides_;        // step per axis

  void eval(const std::vector<array>& inputs, array& out);
};
|
1440 |
+
|
1441 |
+
// Softmax primitive (axis handling lives in the op layer — confirm).
class Softmax : public Primitive {
 public:
  explicit Softmax(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Softmax)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Sorts the input along a single axis.
class Sort : public Primitive {
 public:
  explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sort)
  // Custom equivalence: must compare the sort axis.
  bool is_equivalent(const Primitive& other) const override;

 private:
  int axis_;  // axis sorted along

  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise square.
class Square : public Primitive {
 public:
  explicit Square(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Square)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1499 |
+
|
1500 |
+
// Square root; with recip = true it presumably computes the reciprocal
// square root (rsqrt) instead — name-based, confirm in the op layer.
class Sqrt : public Primitive {
 public:
  explicit Sqrt(Stream stream, bool recip = false)
      : Primitive(stream), recip_(recip){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Sqrt)
  // Custom equivalence: must compare the recip flag.
  bool is_equivalent(const Primitive& other) const override;

 private:
  void eval(const std::vector<array>& inputs, array& out);
  bool recip_;  // true => reciprocal square root — TODO confirm
};
|
1520 |
+
|
1521 |
+
// Identity in the forward pass; declares no gradient macros, consistent
// with blocking gradient flow (name-based — confirm in transforms).
class StopGradient : public Primitive {
 public:
  explicit StopGradient(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(StopGradient)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Element-wise subtraction.
class Subtract : public Primitive {
 public:
  explicit Subtract(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Subtract)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Tangent.
class Tan : public Primitive {
 public:
  explicit Tan(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Tan)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Hyperbolic tangent.
class Tanh : public Primitive {
 public:
  explicit Tanh(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Tanh)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};
|
1595 |
+
|
1596 |
+
// Uniform random sampling primitive. No gradient macros: random
// generation is not differentiable.
class Uniform : public Primitive {
 public:
  explicit Uniform(Stream stream) : Primitive(stream){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_PRINT(Uniform)
  DEFINE_DEFAULT_IS_EQUIVALENT()

 private:
  void eval(const std::vector<array>& inputs, array& out);
};

// Permutes the input's axes according to the stored order.
class Transpose : public Primitive {
 public:
  explicit Transpose(Stream stream, const std::vector<int>& axes)
      : Primitive(stream), axes_(axes){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;

  DEFINE_GRADS()
  DEFINE_PRINT(Transpose)
  // Custom equivalence: must compare the axis permutation.
  bool is_equivalent(const Primitive& other) const override;

 private:
  std::vector<int> axes_;  // permutation of the input axes

  void eval(const std::vector<array>& inputs, array& out);
};
|
1635 |
+
|
1636 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/random.h
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::random {

// Stateful sequence of PRNG keys; next() advances the internal key.
class KeySequence {
 public:
  explicit KeySequence(uint64_t seed);

  void seed(uint64_t seed);
  array next();

  // Process-wide default key sequence (Meyers singleton).
  static KeySequence& default_() {
    static KeySequence ks(0);
    return ks;
  }

 private:
  array key_;
};

/** Get a PRNG key from a seed. */
array key(uint64_t seed);

/** Seed the default PRNG key. */
void seed(uint64_t seed);

/** Generate an array with type uint32 filled with random bits. */
array bits(
    const std::vector<int>& shape,
    int width,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array bits(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  // width 4 presumably means 4 bytes (full uint32 words) — TODO confirm.
  return bits(shape, 4, key, s);
}

/** Split the rng key into a pair of keys. */
std::pair<array, array> split(const array& key, StreamOrDevice s = {});

/** Split the rng key into `num` keys. */
array split(const array& key, int num, StreamOrDevice s = {});

/** Generate uniform random numbers between low and high. */
array uniform(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T, typename U>
array uniform(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
}

/** Generate uniform random numbers between 0 and 1. */
array uniform(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array uniform(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  // Fix: forward the stream/device argument; the original called
  // uniform(shape, float32, key) and silently dropped `s`, so the
  // caller-requested stream was ignored.
  return uniform(shape, float32, key, s);
}

/** Generate samples from the standard normal distribution. */
array normal(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array normal(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return normal(shape, float32, key, s);
}

/** Generate integer samples uniformly at random */
array randint(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T, typename U>
array randint(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return randint(array(low), array(high), shape, dtype, key, to_stream(s));
}  // fix: dropped redundant ';' after the function body

/** Generate binary variables with probability to be true equal to p */
array bernoulli(
    const array& p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
array bernoulli(
    const array& p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T>
array bernoulli(
    T p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), key, s);
}  // fix: dropped redundant ';'

template <typename T>
array bernoulli(
    T p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), shape, key, s);
}  // fix: dropped redundant ';'

// Bernoulli with default p — presumably 0.5; confirm in the implementation.
array bernoulli(
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

// Normal samples truncated to [lower, upper].
array truncated_normal(
    const array& lower,
    const array& upper,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array truncated_normal(
    const array& lower,
    const array& upper,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

// Samples from the standard Gumbel distribution.
array gumbel(
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

// Categorical sampling from unnormalized log-probabilities along `axis`.
array categorical(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array categorical(
    const array& logits_,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::random
|
lib/python3.11/site-packages/mlx/include/mlx/scheduler.h
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <atomic>
|
6 |
+
#include <future>
|
7 |
+
#include <queue>
|
8 |
+
#include <thread>
|
9 |
+
#include <unordered_map>
|
10 |
+
|
11 |
+
#include "mlx/backend/metal/metal.h"
|
12 |
+
#include "mlx/device.h"
|
13 |
+
#include "mlx/stream.h"
|
14 |
+
|
15 |
+
namespace mlx::core::scheduler {
|
16 |
+
|
17 |
+
struct StreamThread {
|
18 |
+
std::mutex mtx;
|
19 |
+
std::queue<std::function<void()>> q;
|
20 |
+
std::condition_variable cond;
|
21 |
+
bool stop;
|
22 |
+
Stream stream;
|
23 |
+
std::thread thread;
|
24 |
+
|
25 |
+
StreamThread(Stream stream)
|
26 |
+
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
27 |
+
|
28 |
+
~StreamThread() {
|
29 |
+
{
|
30 |
+
std::unique_lock<std::mutex> lk(mtx);
|
31 |
+
stop = true;
|
32 |
+
}
|
33 |
+
cond.notify_one();
|
34 |
+
thread.join();
|
35 |
+
}
|
36 |
+
|
37 |
+
void thread_fn() {
|
38 |
+
auto thread_pool = metal::new_scoped_memory_pool();
|
39 |
+
metal::new_stream(stream);
|
40 |
+
while (true) {
|
41 |
+
std::function<void()> task;
|
42 |
+
{
|
43 |
+
std::unique_lock<std::mutex> lk(mtx);
|
44 |
+
cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
|
45 |
+
if (q.empty() && stop) {
|
46 |
+
return;
|
47 |
+
}
|
48 |
+
task = std::move(q.front());
|
49 |
+
q.pop();
|
50 |
+
}
|
51 |
+
task();
|
52 |
+
}
|
53 |
+
}
|
54 |
+
|
55 |
+
template <typename F>
|
56 |
+
void enqueue(F&& f) {
|
57 |
+
{
|
58 |
+
std::unique_lock<std::mutex> lk(mtx);
|
59 |
+
if (stop) {
|
60 |
+
throw std::runtime_error(
|
61 |
+
"Cannot enqueue work after stream is stopped.");
|
62 |
+
}
|
63 |
+
q.emplace(std::forward<F>(f));
|
64 |
+
}
|
65 |
+
cond.notify_one();
|
66 |
+
}
|
67 |
+
};
|
68 |
+
|
69 |
+
class Scheduler {
|
70 |
+
public:
|
71 |
+
Scheduler() : n_active_tasks_(0) {
|
72 |
+
if (metal::is_available()) {
|
73 |
+
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
74 |
+
}
|
75 |
+
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
76 |
+
}
|
77 |
+
|
78 |
+
// Not copyable or moveable
|
79 |
+
Scheduler(const Scheduler&) = delete;
|
80 |
+
Scheduler(Scheduler&&) = delete;
|
81 |
+
Scheduler& operator=(const Scheduler&) = delete;
|
82 |
+
Scheduler& operator=(Scheduler&&) = delete;
|
83 |
+
|
84 |
+
Stream new_stream(const Device& d) {
|
85 |
+
auto stream = Stream(streams_.size(), d);
|
86 |
+
streams_.push_back(new StreamThread{stream});
|
87 |
+
return stream;
|
88 |
+
}
|
89 |
+
|
90 |
+
template <typename F>
|
91 |
+
void enqueue(const Stream& stream, F&& f);
|
92 |
+
|
93 |
+
Stream get_default_stream(const Device& d) {
|
94 |
+
return default_streams_.at(d.type);
|
95 |
+
}
|
96 |
+
|
97 |
+
void set_default_stream(const Stream& s) {
|
98 |
+
default_streams_.at(s.device.type) = s;
|
99 |
+
}
|
100 |
+
|
101 |
+
void notify_new_task(const Stream& stream) {
|
102 |
+
{
|
103 |
+
std::unique_lock<std::mutex> lk(mtx);
|
104 |
+
n_active_tasks_++;
|
105 |
+
}
|
106 |
+
completion_cv.notify_all();
|
107 |
+
}
|
108 |
+
|
109 |
+
void notify_task_completion(const Stream& stream) {
|
110 |
+
{
|
111 |
+
std::unique_lock<std::mutex> lk(mtx);
|
112 |
+
n_active_tasks_--;
|
113 |
+
}
|
114 |
+
completion_cv.notify_all();
|
115 |
+
}
|
116 |
+
|
117 |
+
int n_active_tasks() const {
|
118 |
+
return n_active_tasks_;
|
119 |
+
}
|
120 |
+
|
121 |
+
void wait_for_one() {
|
122 |
+
std::unique_lock<std::mutex> lk(mtx);
|
123 |
+
int n_tasks_old = n_active_tasks();
|
124 |
+
if (n_tasks_old > 1) {
|
125 |
+
completion_cv.wait(lk, [this, n_tasks_old] {
|
126 |
+
return this->n_active_tasks() != n_tasks_old;
|
127 |
+
});
|
128 |
+
}
|
129 |
+
}
|
130 |
+
|
131 |
+
~Scheduler() {
|
132 |
+
for (auto s : streams_) {
|
133 |
+
delete s;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
private:
|
138 |
+
int n_active_tasks_;
|
139 |
+
std::vector<StreamThread*> streams_;
|
140 |
+
std::unordered_map<Device::DeviceType, Stream> default_streams_;
|
141 |
+
std::condition_variable completion_cv;
|
142 |
+
std::mutex mtx;
|
143 |
+
};
|
144 |
+
|
145 |
+
template <typename F>
|
146 |
+
void Scheduler::enqueue(const Stream& stream, F&& f) {
|
147 |
+
streams_[stream.index]->enqueue(std::forward<F>(f));
|
148 |
+
}
|
149 |
+
|
150 |
+
Scheduler& scheduler();
|
151 |
+
|
152 |
+
template <typename F>
|
153 |
+
void enqueue(const Stream& stream, F&& f) {
|
154 |
+
scheduler().enqueue(stream, std::forward<F>(f));
|
155 |
+
}
|
156 |
+
|
157 |
+
inline int n_active_tasks() {
|
158 |
+
return scheduler().n_active_tasks();
|
159 |
+
}
|
160 |
+
|
161 |
+
inline void notify_new_task(const Stream& stream) {
|
162 |
+
scheduler().notify_new_task(stream);
|
163 |
+
}
|
164 |
+
|
165 |
+
inline void notify_task_completion(const Stream& stream) {
|
166 |
+
scheduler().notify_task_completion(stream);
|
167 |
+
}
|
168 |
+
|
169 |
+
inline void wait_for_one() {
|
170 |
+
scheduler().wait_for_one();
|
171 |
+
}
|
172 |
+
|
173 |
+
} // namespace mlx::core::scheduler
|
lib/python3.11/site-packages/mlx/include/mlx/stream.h
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "mlx/device.h"
|
6 |
+
|
7 |
+
namespace mlx::core {
|
8 |
+
|
9 |
+
struct Stream {
|
10 |
+
int index;
|
11 |
+
Device device;
|
12 |
+
explicit Stream(int index, Device device) : index(index), device(device) {}
|
13 |
+
};
|
14 |
+
|
15 |
+
/** Get the default stream for the given device. */
|
16 |
+
Stream default_stream(Device d);
|
17 |
+
|
18 |
+
/** Make the stream the default for its device. */
|
19 |
+
void set_default_stream(Stream s);
|
20 |
+
|
21 |
+
/** Make a new stream on the given device. */
|
22 |
+
Stream new_stream(Device d);
|
23 |
+
|
24 |
+
inline bool operator==(const Stream& lhs, const Stream& rhs) {
|
25 |
+
return lhs.index == rhs.index;
|
26 |
+
}
|
27 |
+
|
28 |
+
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
|
29 |
+
return !(lhs == rhs);
|
30 |
+
}
|
31 |
+
|
32 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/transforms.h
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "array.h"
|
6 |
+
|
7 |
+
namespace mlx::core {
|
8 |
+
|
9 |
+
/** Fuse equivalent arrays to avoid duplicate execution. */
|
10 |
+
void simplify(const std::vector<array>& outputs);
|
11 |
+
|
12 |
+
template <typename... Arrays>
|
13 |
+
void simplify(Arrays... outputs) {
|
14 |
+
simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
|
15 |
+
}
|
16 |
+
|
17 |
+
void eval(const std::vector<array>& outputs, bool retain_graph = false);
|
18 |
+
|
19 |
+
template <typename... Arrays>
|
20 |
+
void eval(Arrays... outputs) {
|
21 |
+
eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
|
22 |
+
}
|
23 |
+
|
24 |
+
/**
|
25 |
+
* Computes the output and vector-Jacobian product (VJP) of a function.
|
26 |
+
*
|
27 |
+
* Computes the vector-Jacobian product of the vector of cotangents with the
|
28 |
+
* Jacobian of the function evaluated at the primals. Returns a pair of
|
29 |
+
* vectors of output arrays and VJP arrays.
|
30 |
+
**/
|
31 |
+
std::pair<std::vector<array>, std::vector<array>> vjp(
|
32 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
33 |
+
const std::vector<array>& primals,
|
34 |
+
const std::vector<array>& cotangents);
|
35 |
+
|
36 |
+
/**
|
37 |
+
* Computes the output and vector-Jacobian product (VJP) of a unary function.
|
38 |
+
*/
|
39 |
+
std::pair<array, array> vjp(
|
40 |
+
const std::function<array(const array&)>& fun,
|
41 |
+
const array& primal,
|
42 |
+
const array& cotangent);
|
43 |
+
|
44 |
+
/**
|
45 |
+
* Computes the output and Jacobian-vector product (JVP) of a function.
|
46 |
+
*
|
47 |
+
* Computes the Jacobian-vector product of the Jacobian of the function
|
48 |
+
* evaluated at the primals with the vector of tangents. Returns a pair of
|
49 |
+
* vectors of output arrays and JVP arrays.
|
50 |
+
**/
|
51 |
+
std::pair<std::vector<array>, std::vector<array>> jvp(
|
52 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
53 |
+
const std::vector<array>& primals,
|
54 |
+
const std::vector<array>& tangents);
|
55 |
+
|
56 |
+
/**
|
57 |
+
* Computes the output and Jacobian-vector product (JVP) of a unary function.
|
58 |
+
*/
|
59 |
+
std::pair<array, array> jvp(
|
60 |
+
const std::function<array(const array&)>& fun,
|
61 |
+
const array& primal,
|
62 |
+
const array& tangent);
|
63 |
+
|
64 |
+
// Return type of general value_and_grad: a function which takes an input
|
65 |
+
// vector of arrays and returns a pair of vectors of arrays one for the
|
66 |
+
// values and one for the gradients wrt the first value.
|
67 |
+
using ValueAndGradFn =
|
68 |
+
std::function<std::pair<std::vector<array>, std::vector<array>>(
|
69 |
+
const std::vector<array>&)>;
|
70 |
+
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
|
71 |
+
const std::vector<array>&)>;
|
72 |
+
|
73 |
+
/**
|
74 |
+
* Returns a function which computes the value and gradient of the input
|
75 |
+
* function with respect to a vector of input arrays.
|
76 |
+
**/
|
77 |
+
ValueAndGradFn value_and_grad(
|
78 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
79 |
+
const std::vector<int>& argnums);
|
80 |
+
|
81 |
+
/**
|
82 |
+
* Returns a function which computes the value and gradient of the input
|
83 |
+
* function with respect to a single input array.
|
84 |
+
**/
|
85 |
+
ValueAndGradFn inline value_and_grad(
|
86 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
87 |
+
int argnum = 0) {
|
88 |
+
return value_and_grad(fun, std::vector<int>{argnum});
|
89 |
+
}
|
90 |
+
|
91 |
+
/**
|
92 |
+
* Returns a function which computes the value and gradient of the unary
|
93 |
+
* input function.
|
94 |
+
**/
|
95 |
+
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
|
96 |
+
const std::function<array(const array&)>& fun) {
|
97 |
+
return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
|
98 |
+
}
|
99 |
+
|
100 |
+
SimpleValueAndGradFn inline value_and_grad(
|
101 |
+
const std::function<array(const std::vector<array>&)>& fun,
|
102 |
+
const std::vector<int>& argnums) {
|
103 |
+
return [fun, argnums](auto inputs) {
|
104 |
+
auto result = value_and_grad(
|
105 |
+
[fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
|
106 |
+
argnums)(inputs);
|
107 |
+
|
108 |
+
return std::make_pair(result.first[0], result.second);
|
109 |
+
};
|
110 |
+
}
|
111 |
+
|
112 |
+
SimpleValueAndGradFn inline value_and_grad(
|
113 |
+
const std::function<array(const std::vector<array>&)>& fun,
|
114 |
+
int argnum = 0) {
|
115 |
+
return value_and_grad(fun, std::vector<int>{argnum});
|
116 |
+
}
|
117 |
+
|
118 |
+
/**
|
119 |
+
* Returns a function which computes the gradient of the input function with
|
120 |
+
* respect to a vector of input arrays.
|
121 |
+
*
|
122 |
+
* The function being differentiated takes a vector of arrays and returns an
|
123 |
+
* array. The vector of `argnums` specifies which the arguments to compute
|
124 |
+
* the gradient with respect to. At least one argument must be specified.
|
125 |
+
**/
|
126 |
+
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
127 |
+
const std::function<array(const std::vector<array>&)>& fun,
|
128 |
+
const std::vector<int>& argnums) {
|
129 |
+
auto fn = value_and_grad(fun, argnums);
|
130 |
+
return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
|
131 |
+
}
|
132 |
+
|
133 |
+
/**
|
134 |
+
* Returns a function which computes the gradient of the input function with
|
135 |
+
* respect to a single input array.
|
136 |
+
*
|
137 |
+
* The function being differentiated takes a vector of arrays and returns an
|
138 |
+
* array. The optional `argnum` index specifies which the argument to compute
|
139 |
+
* the gradient with respect to and defaults to 0.
|
140 |
+
**/
|
141 |
+
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
|
142 |
+
const std::function<array(const std::vector<array>&)>& fun,
|
143 |
+
int argnum = 0) {
|
144 |
+
return grad(fun, std::vector<int>{argnum});
|
145 |
+
}
|
146 |
+
|
147 |
+
/**
|
148 |
+
* Returns a function which computes the gradient of the unary input function.
|
149 |
+
**/
|
150 |
+
std::function<array(const array&)> inline grad(
|
151 |
+
const std::function<array(const array&)>& fun) {
|
152 |
+
auto fn = value_and_grad(fun);
|
153 |
+
return [fn](const array& input) { return fn(input).second; };
|
154 |
+
}
|
155 |
+
|
156 |
+
/**
|
157 |
+
* Automatically vectorize a unary function over the requested axes.
|
158 |
+
*/
|
159 |
+
std::function<array(const array&)> vmap(
|
160 |
+
const std::function<array(const array&)>& fun,
|
161 |
+
int in_axis = 0,
|
162 |
+
int out_axis = 0);
|
163 |
+
|
164 |
+
/**
|
165 |
+
* Automatically vectorize a binary function over the requested axes.
|
166 |
+
*/
|
167 |
+
std::function<array(const array&, const array&)> vmap(
|
168 |
+
const std::function<array(const array&, const array&)>& fun,
|
169 |
+
int in_axis_a = 0,
|
170 |
+
int in_axis_b = 0,
|
171 |
+
int out_axis = 0);
|
172 |
+
|
173 |
+
/**
|
174 |
+
* Automatically vectorize a function over the requested axes.
|
175 |
+
*
|
176 |
+
* The input function to `vmap` takes as an argument a vector of arrays and
|
177 |
+
* returns a vector of arrays. Optionally specify the axes to vectorize over
|
178 |
+
* with `in_axes` and `out_axes`, otherwise a default of 0 is used.
|
179 |
+
* Returns a vectorized function with the same signature as the input
|
180 |
+
* function.
|
181 |
+
*/
|
182 |
+
std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
183 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
184 |
+
const std::vector<int>& in_axes = {},
|
185 |
+
const std::vector<int>& out_axes = {});
|
186 |
+
|
187 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
namespace mlx::core::detail {
|
4 |
+
|
5 |
+
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
|
6 |
+
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
7 |
+
const std::vector<array>& inputs,
|
8 |
+
const std::vector<int>& in_axes);
|
9 |
+
|
10 |
+
std::vector<array> vmap_replace(
|
11 |
+
const std::vector<array>& inputs,
|
12 |
+
const std::vector<array>& s_inputs,
|
13 |
+
const std::vector<array>& s_outputs,
|
14 |
+
const std::vector<int>& in_axes,
|
15 |
+
const std::vector<int>& out_axes);
|
16 |
+
|
17 |
+
} // namespace mlx::core::detail
|
lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <algorithm>
|
6 |
+
#include <cmath>
|
7 |
+
#include <cstdint>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
#define __MLX_BFLOAT_NAN__ 0x7FC0
|
11 |
+
|
12 |
+
namespace mlx::core {
|
13 |
+
|
14 |
+
namespace {
|
15 |
+
union float_bits_bf16 {
|
16 |
+
float f;
|
17 |
+
uint32_t u;
|
18 |
+
};
|
19 |
+
} // namespace
|
20 |
+
|
21 |
+
struct _MLX_BFloat16 {
|
22 |
+
uint16_t bits_;
|
23 |
+
|
24 |
+
// Default constructor
|
25 |
+
_MLX_BFloat16() = default;
|
26 |
+
|
27 |
+
// Default copy constructor
|
28 |
+
_MLX_BFloat16(_MLX_BFloat16 const&) = default;
|
29 |
+
|
30 |
+
// Appease std::vector<bool> for being special
|
31 |
+
_MLX_BFloat16& operator=(std::vector<bool>::reference x) {
|
32 |
+
bits_ = x;
|
33 |
+
return *this;
|
34 |
+
}
|
35 |
+
|
36 |
+
_MLX_BFloat16& operator=(const float& x) {
|
37 |
+
return (*this = _MLX_BFloat16(x));
|
38 |
+
}
|
39 |
+
|
40 |
+
// From float32
|
41 |
+
_MLX_BFloat16(const float& x) {
|
42 |
+
if (std::isnan(x)) {
|
43 |
+
bits_ = __MLX_BFLOAT_NAN__;
|
44 |
+
} else {
|
45 |
+
// Union
|
46 |
+
float_bits_bf16 in;
|
47 |
+
|
48 |
+
// Take bits
|
49 |
+
in.f = x;
|
50 |
+
|
51 |
+
// Round to nearest even
|
52 |
+
in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
|
53 |
+
|
54 |
+
// Take upper 16 bits
|
55 |
+
bits_ = in.u >> 16;
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
// To float32
|
60 |
+
operator float() const {
|
61 |
+
// Union
|
62 |
+
float_bits_bf16 out;
|
63 |
+
|
64 |
+
// Upper 16 bits are the data and lower 16 bits are 0s
|
65 |
+
out.u = ((uint32_t)bits_) << 16;
|
66 |
+
|
67 |
+
return out.f;
|
68 |
+
}
|
69 |
+
};
|
70 |
+
|
71 |
+
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
72 |
+
inline otype __operator__(atype lhs, btype rhs) { \
|
73 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
74 |
+
}
|
75 |
+
|
76 |
+
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
77 |
+
inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
78 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
79 |
+
} \
|
80 |
+
inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
81 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
82 |
+
}
|
83 |
+
|
84 |
+
// Operators
|
85 |
+
#define bfloat_binop(_op_, _operator_) \
|
86 |
+
bfloat_binop_base( \
|
87 |
+
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
88 |
+
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
89 |
+
bfloat_binop_helper(_op_, _operator_, double, double, double); \
|
90 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \
|
91 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
92 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
93 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
94 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
95 |
+
|
96 |
+
bfloat_binop(+, operator+);
|
97 |
+
bfloat_binop(-, operator-);
|
98 |
+
bfloat_binop(*, operator*);
|
99 |
+
bfloat_binop(/, operator/);
|
100 |
+
|
101 |
+
#undef bfloat_binop
|
102 |
+
|
103 |
+
// Comparison ops
|
104 |
+
#define bfloat_compop(__op__, __operator__) \
|
105 |
+
bfloat_binop_base( \
|
106 |
+
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
107 |
+
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
108 |
+
bfloat_binop_helper(__op__, __operator__, bool, double, double); \
|
109 |
+
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
110 |
+
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
111 |
+
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
112 |
+
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
113 |
+
|
114 |
+
bfloat_compop(>, operator>);
|
115 |
+
bfloat_compop(<, operator<);
|
116 |
+
bfloat_compop(>=, operator>=);
|
117 |
+
bfloat_compop(<=, operator<=);
|
118 |
+
bfloat_compop(==, operator==);
|
119 |
+
bfloat_compop(!=, operator!=);
|
120 |
+
|
121 |
+
#undef bfloat_compop
|
122 |
+
|
123 |
+
// Negative
|
124 |
+
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
|
125 |
+
return -static_cast<float>(lhs);
|
126 |
+
}
|
127 |
+
|
128 |
+
// Inplace ops
|
129 |
+
#define bfloat_inplace_op(__op__, __operator__) \
|
130 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
|
131 |
+
lhs = lhs __op__ rhs; \
|
132 |
+
return lhs; \
|
133 |
+
} \
|
134 |
+
inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \
|
135 |
+
lhs = lhs __op__ rhs; \
|
136 |
+
return lhs; \
|
137 |
+
}
|
138 |
+
|
139 |
+
bfloat_inplace_op(+, operator+=);
|
140 |
+
bfloat_inplace_op(-, operator-=);
|
141 |
+
bfloat_inplace_op(*, operator*=);
|
142 |
+
bfloat_inplace_op(/, operator/=);
|
143 |
+
|
144 |
+
#undef bfloat_inplace_op
|
145 |
+
|
146 |
+
// Bitwise ops
|
147 |
+
|
148 |
+
#define bfloat_bitop(__op__, __operator__) \
|
149 |
+
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
|
150 |
+
_MLX_BFloat16 out; \
|
151 |
+
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
152 |
+
return out; \
|
153 |
+
} \
|
154 |
+
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \
|
155 |
+
_MLX_BFloat16 out; \
|
156 |
+
out.bits_ = lhs.bits_ __op__ rhs; \
|
157 |
+
return out; \
|
158 |
+
} \
|
159 |
+
inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \
|
160 |
+
_MLX_BFloat16 out; \
|
161 |
+
out.bits_ = lhs __op__ rhs.bits_; \
|
162 |
+
return out; \
|
163 |
+
}
|
164 |
+
|
165 |
+
bfloat_bitop(|, operator|);
|
166 |
+
bfloat_bitop(&, operator&);
|
167 |
+
bfloat_bitop(^, operator^);
|
168 |
+
|
169 |
+
#undef bfloat_bitop
|
170 |
+
|
171 |
+
#define bfloat_inplace_bitop(__op__, __operator__) \
|
172 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
173 |
+
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
174 |
+
return lhs; \
|
175 |
+
} \
|
176 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \
|
177 |
+
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
178 |
+
return lhs; \
|
179 |
+
}
|
180 |
+
|
181 |
+
bfloat_inplace_bitop(|, operator|=);
|
182 |
+
bfloat_inplace_bitop(&, operator&=);
|
183 |
+
bfloat_inplace_bitop(^, operator^=);
|
184 |
+
|
185 |
+
#undef bfloat_inplace_bitop
|
186 |
+
|
187 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/types/complex.h
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
#include <complex>
|
5 |
+
#include "mlx/types/half_types.h"
|
6 |
+
|
7 |
+
namespace mlx::core {
|
8 |
+
|
9 |
+
struct complex64_t;
|
10 |
+
|
11 |
+
template <typename T>
|
12 |
+
static constexpr bool can_convert_to_complex64 =
|
13 |
+
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
|
14 |
+
|
15 |
+
struct complex64_t : public std::complex<float> {
|
16 |
+
complex64_t(float v, float u) : std::complex<float>(v, u){};
|
17 |
+
complex64_t(std::complex<float> v) : std::complex<float>(v){};
|
18 |
+
|
19 |
+
template <
|
20 |
+
typename T,
|
21 |
+
typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
|
22 |
+
complex64_t(T x) : std::complex<float>(x){};
|
23 |
+
|
24 |
+
operator float() const {
|
25 |
+
return real();
|
26 |
+
};
|
27 |
+
};
|
28 |
+
|
29 |
+
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
|
30 |
+
return (a.real() > b.real()) ||
|
31 |
+
(a.real() == b.real() && a.imag() >= b.imag());
|
32 |
+
}
|
33 |
+
|
34 |
+
inline bool operator>(const complex64_t& a, const complex64_t& b) {
|
35 |
+
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
36 |
+
}
|
37 |
+
|
38 |
+
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
|
39 |
+
return operator>=(b, a);
|
40 |
+
}
|
41 |
+
|
42 |
+
inline bool operator<(const complex64_t& a, const complex64_t& b) {
|
43 |
+
return operator>(b, a);
|
44 |
+
}
|
45 |
+
|
46 |
+
inline complex64_t operator-(const complex64_t& v) {
|
47 |
+
return -static_cast<std::complex<float>>(v);
|
48 |
+
}
|
49 |
+
|
50 |
+
// clang-format off
|
51 |
+
#define complex_binop_helper(_op_, _operator_, itype) \
|
52 |
+
inline complex64_t _operator_(itype x, const complex64_t& y) { \
|
53 |
+
return x _op_ static_cast<std::complex<float>>(y); \
|
54 |
+
} \
|
55 |
+
inline complex64_t _operator_(const complex64_t& x, itype y) { \
|
56 |
+
return static_cast<std::complex<float>>(x) _op_ y; \
|
57 |
+
}
|
58 |
+
|
59 |
+
#define complex_binop(_op_, _operator_) \
|
60 |
+
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
|
61 |
+
return static_cast<std::complex<float>>(x) \
|
62 |
+
_op_ static_cast<std::complex<float>>(y); \
|
63 |
+
} \
|
64 |
+
complex_binop_helper(_op_, _operator_, bool) \
|
65 |
+
complex_binop_helper(_op_, _operator_, uint32_t) \
|
66 |
+
complex_binop_helper(_op_, _operator_, uint64_t) \
|
67 |
+
complex_binop_helper(_op_, _operator_, int32_t) \
|
68 |
+
complex_binop_helper(_op_, _operator_, int64_t) \
|
69 |
+
complex_binop_helper(_op_, _operator_, float16_t) \
|
70 |
+
complex_binop_helper(_op_, _operator_, bfloat16_t) \
|
71 |
+
complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
|
72 |
+
complex_binop_helper(_op_, _operator_, float)
|
73 |
+
// clang-format on
|
74 |
+
|
75 |
+
complex_binop(+, operator+)
|
76 |
+
|
77 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <algorithm>
|
6 |
+
#include <cmath>
|
7 |
+
#include <cstdint>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
#define __MLX_HALF_NAN__ 0x7D00
|
11 |
+
|
12 |
+
namespace mlx::core {
|
13 |
+
|
14 |
+
namespace {
|
15 |
+
union float_bits_fp16 {
|
16 |
+
float f;
|
17 |
+
uint32_t u;
|
18 |
+
};
|
19 |
+
} // namespace
|
20 |
+
|
21 |
+
struct _MLX_Float16 {
|
22 |
+
uint16_t bits_;
|
23 |
+
|
24 |
+
// Default constructor
|
25 |
+
_MLX_Float16() = default;
|
26 |
+
|
27 |
+
// Default copy constructor
|
28 |
+
_MLX_Float16(_MLX_Float16 const&) = default;
|
29 |
+
|
30 |
+
// Appease std::vector<bool> for being special
|
31 |
+
_MLX_Float16& operator=(std::vector<bool>::reference x) {
|
32 |
+
bits_ = x;
|
33 |
+
return *this;
|
34 |
+
}
|
35 |
+
|
36 |
+
_MLX_Float16& operator=(const float& x) {
|
37 |
+
return (*this = _MLX_Float16(x));
|
38 |
+
}
|
39 |
+
|
40 |
+
// From float32
|
41 |
+
_MLX_Float16(const float& x) : bits_(0) {
|
42 |
+
// Conversion following
|
43 |
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
44 |
+
|
45 |
+
// Union
|
46 |
+
float_bits_fp16 in;
|
47 |
+
|
48 |
+
// Take fp32 bits
|
49 |
+
in.f = x;
|
50 |
+
|
51 |
+
// Find and take sign bit
|
52 |
+
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
|
53 |
+
uint16_t x_sign_16 = (x_sign_32 >> 16);
|
54 |
+
|
55 |
+
if (std::isnan(x)) {
|
56 |
+
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
|
57 |
+
} else {
|
58 |
+
// Union
|
59 |
+
float_bits_fp16 inf_scale, zero_scale, magic_bits;
|
60 |
+
|
61 |
+
// Find exponent bits and take the max supported by half
|
62 |
+
uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
|
63 |
+
uint32_t max_expo_32 = uint32_t(0x38800000);
|
64 |
+
x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
|
65 |
+
x_expo_32 += uint32_t(15) << 23;
|
66 |
+
|
67 |
+
// Handle scaling to inf as needed
|
68 |
+
inf_scale.u = uint32_t(0x77800000);
|
69 |
+
zero_scale.u = uint32_t(0x08800000);
|
70 |
+
|
71 |
+
// Combine with magic and let addition do rounding
|
72 |
+
magic_bits.u = x_expo_32;
|
73 |
+
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
74 |
+
|
75 |
+
// Take the lower 5 bits of the exponent
|
76 |
+
uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
|
77 |
+
|
78 |
+
// Collect the lower 12 bits which have the mantissa
|
79 |
+
uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
|
80 |
+
|
81 |
+
// Combine sign, exp and mantissa
|
82 |
+
bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
// To float32
|
87 |
+
operator float() const {
|
88 |
+
// Conversion following
|
89 |
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
90 |
+
|
91 |
+
// Union
|
92 |
+
float_bits_fp16 out;
|
93 |
+
|
94 |
+
uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
|
95 |
+
uint32_t base = (bits_ << 16);
|
96 |
+
uint32_t two_base = base + base;
|
97 |
+
|
98 |
+
uint32_t denorm_max = 1u << 27;
|
99 |
+
if (two_base < denorm_max) {
|
100 |
+
out.u = uint32_t(126) << 23; // magic mask
|
101 |
+
out.u |= (two_base >> 17); // Bits from fp16
|
102 |
+
out.f -= 0.5f; // magic bias
|
103 |
+
} else {
|
104 |
+
out.u = uint32_t(0xE0) << 23; // exponent offset
|
105 |
+
out.u += (two_base >> 4); // Bits from fp16
|
106 |
+
float out_unscaled = out.f; // Store value
|
107 |
+
out.u = uint32_t(0x7800000); // exponent scale
|
108 |
+
out.f *= out_unscaled;
|
109 |
+
}
|
110 |
+
|
111 |
+
// Add sign
|
112 |
+
out.u |= x_sign_32;
|
113 |
+
|
114 |
+
return out.f;
|
115 |
+
}
|
116 |
+
};
|
117 |
+
|
118 |
+
#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
119 |
+
inline otype __operator__(atype lhs, btype rhs) { \
|
120 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
121 |
+
}
|
122 |
+
|
123 |
+
#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
124 |
+
inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \
|
125 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
126 |
+
} \
|
127 |
+
inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \
|
128 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
129 |
+
}
|
130 |
+
|
131 |
+
// Operators
|
132 |
+
#define half_binop(__op__, __operator__) \
|
133 |
+
half_binop_base( \
|
134 |
+
__op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
|
135 |
+
half_binop_helper(__op__, __operator__, float, float, float); \
|
136 |
+
half_binop_helper(__op__, __operator__, double, double, double); \
|
137 |
+
half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
|
138 |
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
|
139 |
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
|
140 |
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
|
141 |
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
|
142 |
+
|
143 |
+
half_binop(+, operator+);
|
144 |
+
half_binop(-, operator-);
|
145 |
+
half_binop(*, operator*);
|
146 |
+
half_binop(/, operator/);
|
147 |
+
|
148 |
+
#undef half_binop
|
149 |
+
|
150 |
+
// Comparison ops
|
151 |
+
#define half_compop(__op__, __operator__) \
|
152 |
+
half_binop_base( \
|
153 |
+
__op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
|
154 |
+
half_binop_helper(__op__, __operator__, bool, float, float); \
|
155 |
+
half_binop_helper(__op__, __operator__, bool, double, double); \
|
156 |
+
half_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
157 |
+
half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
158 |
+
half_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
159 |
+
half_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
160 |
+
|
161 |
+
half_compop(>, operator>);
|
162 |
+
half_compop(<, operator<);
|
163 |
+
half_compop(>=, operator>=);
|
164 |
+
half_compop(<=, operator<=);
|
165 |
+
half_compop(==, operator==);
|
166 |
+
half_compop(!=, operator!=);
|
167 |
+
|
168 |
+
#undef half_compop
|
169 |
+
|
170 |
+
// Negative
|
171 |
+
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
|
172 |
+
return -static_cast<float>(lhs);
|
173 |
+
}
|
174 |
+
|
175 |
+
// Inplace ops
|
176 |
+
#define half_inplace_op(__op__, __operator__) \
|
177 |
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
|
178 |
+
lhs = lhs __op__ rhs; \
|
179 |
+
return lhs; \
|
180 |
+
} \
|
181 |
+
inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
|
182 |
+
lhs = lhs __op__ rhs; \
|
183 |
+
return lhs; \
|
184 |
+
}
|
185 |
+
|
186 |
+
half_inplace_op(+, operator+=);
|
187 |
+
half_inplace_op(-, operator-=);
|
188 |
+
half_inplace_op(*, operator*=);
|
189 |
+
half_inplace_op(/, operator/=);
|
190 |
+
|
191 |
+
#undef half_inplace_op
|
192 |
+
|
193 |
+
// Bitwise ops
|
194 |
+
|
195 |
+
#define half_bitop(__op__, __operator__) \
|
196 |
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
|
197 |
+
_MLX_Float16 out; \
|
198 |
+
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
199 |
+
return out; \
|
200 |
+
} \
|
201 |
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
|
202 |
+
_MLX_Float16 out; \
|
203 |
+
out.bits_ = lhs.bits_ __op__ rhs; \
|
204 |
+
return out; \
|
205 |
+
} \
|
206 |
+
inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
|
207 |
+
_MLX_Float16 out; \
|
208 |
+
out.bits_ = lhs __op__ rhs.bits_; \
|
209 |
+
return out; \
|
210 |
+
}
|
211 |
+
|
212 |
+
half_bitop(|, operator|);
|
213 |
+
half_bitop(&, operator&);
|
214 |
+
half_bitop(^, operator^);
|
215 |
+
|
216 |
+
#undef half_bitop
|
217 |
+
|
218 |
+
#define half_inplace_bitop(__op__, __operator__) \
|
219 |
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
|
220 |
+
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
221 |
+
return lhs; \
|
222 |
+
} \
|
223 |
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
|
224 |
+
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
225 |
+
return lhs; \
|
226 |
+
}
|
227 |
+
|
228 |
+
half_inplace_bitop(|, operator|=);
|
229 |
+
half_inplace_bitop(&, operator&=);
|
230 |
+
half_inplace_bitop(^, operator^=);
|
231 |
+
|
232 |
+
#undef half_inplace_bitop
|
233 |
+
|
234 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
5 |
+
|
6 |
+
#include <arm_fp16.h>
|
7 |
+
namespace mlx::core {
|
8 |
+
typedef __fp16 float16_t;
|
9 |
+
} // namespace mlx::core
|
10 |
+
|
11 |
+
#else
|
12 |
+
|
13 |
+
#define ADD_HALF_BINOPS
|
14 |
+
#include "mlx/types/fp16.h"
|
15 |
+
namespace mlx::core {
|
16 |
+
typedef struct _MLX_Float16 float16_t;
|
17 |
+
} // namespace mlx::core
|
18 |
+
|
19 |
+
#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
20 |
+
#ifdef __ARM_FEATURE_BF16
|
21 |
+
|
22 |
+
#include <arm_bf16.h>
|
23 |
+
namespace mlx::core {
|
24 |
+
typedef __bf16 bfloat16_t;
|
25 |
+
} // namespace mlx::core
|
26 |
+
|
27 |
+
#else
|
28 |
+
|
29 |
+
#define ADD_HALF_BINOPS
|
30 |
+
#include "mlx/types/bf16.h"
|
31 |
+
namespace mlx::core {
|
32 |
+
typedef struct _MLX_BFloat16 bfloat16_t;
|
33 |
+
} // namespace mlx::core
|
34 |
+
|
35 |
+
#endif // __ARM_FEATURE_BF16
|
36 |
+
|
37 |
+
#ifdef ADD_HALF_BINOPS
|
38 |
+
namespace mlx::core {
|
39 |
+
|
40 |
+
// clang-format off
|
41 |
+
#define fp16_bf16_binop_helper(__op__, __operator__) \
|
42 |
+
inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
|
43 |
+
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
44 |
+
} \
|
45 |
+
inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
|
46 |
+
return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
|
47 |
+
}
|
48 |
+
|
49 |
+
fp16_bf16_binop_helper(+, operator+)
|
50 |
+
fp16_bf16_binop_helper(-, operator-)
|
51 |
+
fp16_bf16_binop_helper(*, operator*)
|
52 |
+
fp16_bf16_binop_helper(/, operator/)
|
53 |
+
// clang-format on
|
54 |
+
|
55 |
+
} // namespace mlx::core
|
56 |
+
#endif
|
lib/python3.11/site-packages/mlx/include/mlx/utils.h
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "array.h"
|
6 |
+
#include "device.h"
|
7 |
+
#include "dtype.h"
|
8 |
+
#include "stream.h"
|
9 |
+
|
10 |
+
namespace mlx::core {
|
11 |
+
|
12 |
+
/** The type from promoting the arrays' types with one another. */
|
13 |
+
Dtype result_type(const std::vector<array>& arrays);
|
14 |
+
|
15 |
+
std::vector<int> broadcast_shapes(
|
16 |
+
const std::vector<int>& s1,
|
17 |
+
const std::vector<int>& s2);
|
18 |
+
|
19 |
+
bool is_same_shape(const std::vector<array>& arrays);
|
20 |
+
|
21 |
+
/**
|
22 |
+
* Returns the axis normalized to be in the range [0, ndim).
|
23 |
+
* Based on numpy's normalize_axis_index. See
|
24 |
+
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
25 |
+
*/
|
26 |
+
int normalize_axis(int axis, int ndim);
|
27 |
+
|
28 |
+
std::ostream& operator<<(std::ostream& os, const Device& d);
|
29 |
+
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
30 |
+
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
31 |
+
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
32 |
+
std::ostream& operator<<(std::ostream& os, array a);
|
33 |
+
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
34 |
+
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
|
35 |
+
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
36 |
+
return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
|
37 |
+
}
|
38 |
+
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
39 |
+
return os << static_cast<float>(v);
|
40 |
+
}
|
41 |
+
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
42 |
+
return os << static_cast<float>(v);
|
43 |
+
}
|
44 |
+
} // namespace mlx::core
|
lib/python3.11/site-packages/mlx/lib/libmlx.dylib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8abefe46a1f39c92b28464814f05a730fa9899b17757703403c6ef362f06ac93
|
3 |
+
size 12420704
|
lib/python3.11/site-packages/mlx/lib/mlx.metallib
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2eedf41000ed270283da11d889bb101aa4c88c6f8f0ec68fe6b040a5be424501
|
3 |
+
size 59495531
|
lib/python3.11/site-packages/mlx/nn/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from mlx.nn import losses
|
4 |
+
from mlx.nn.layers import *
|
5 |
+
from mlx.nn.utils import value_and_grad
|
lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (381 Bytes). View file
|
|
lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc
ADDED
Binary file (15.3 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.78 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from mlx.nn.layers.activations import (
|
4 |
+
CELU,
|
5 |
+
ELU,
|
6 |
+
GELU,
|
7 |
+
SELU,
|
8 |
+
Hardswish,
|
9 |
+
LeakyReLU,
|
10 |
+
LogSigmoid,
|
11 |
+
LogSoftmax,
|
12 |
+
Mish,
|
13 |
+
PReLU,
|
14 |
+
ReLU,
|
15 |
+
ReLU6,
|
16 |
+
SiLU,
|
17 |
+
Softmax,
|
18 |
+
Softplus,
|
19 |
+
Softsign,
|
20 |
+
Step,
|
21 |
+
Tanh,
|
22 |
+
celu,
|
23 |
+
elu,
|
24 |
+
gelu,
|
25 |
+
gelu_approx,
|
26 |
+
gelu_fast_approx,
|
27 |
+
hardswish,
|
28 |
+
leaky_relu,
|
29 |
+
log_sigmoid,
|
30 |
+
log_softmax,
|
31 |
+
mish,
|
32 |
+
prelu,
|
33 |
+
relu,
|
34 |
+
relu6,
|
35 |
+
selu,
|
36 |
+
silu,
|
37 |
+
softmax,
|
38 |
+
softplus,
|
39 |
+
softsign,
|
40 |
+
step,
|
41 |
+
tanh,
|
42 |
+
)
|
43 |
+
from mlx.nn.layers.base import Module
|
44 |
+
from mlx.nn.layers.containers import Sequential
|
45 |
+
from mlx.nn.layers.convolution import Conv1d, Conv2d
|
46 |
+
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
|
47 |
+
from mlx.nn.layers.embedding import Embedding
|
48 |
+
from mlx.nn.layers.linear import Bilinear, Identity, Linear
|
49 |
+
from mlx.nn.layers.normalization import (
|
50 |
+
BatchNorm,
|
51 |
+
GroupNorm,
|
52 |
+
InstanceNorm,
|
53 |
+
LayerNorm,
|
54 |
+
RMSNorm,
|
55 |
+
)
|
56 |
+
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
|
57 |
+
from mlx.nn.layers.quantized import QuantizedLinear
|
58 |
+
from mlx.nn.layers.transformer import (
|
59 |
+
MultiHeadAttention,
|
60 |
+
Transformer,
|
61 |
+
TransformerEncoder,
|
62 |
+
TransformerEncoderLayer,
|
63 |
+
)
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.52 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc
ADDED
Binary file (20.9 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc
ADDED
Binary file (28.1 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc
ADDED
Binary file (1.47 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc
ADDED
Binary file (6.36 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc
ADDED
Binary file (6.71 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc
ADDED
Binary file (2.07 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc
ADDED
Binary file (6.92 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc
ADDED
Binary file (17.7 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc
ADDED
Binary file (10.6 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc
ADDED
Binary file (6.34 kB). View file
|
|
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (18 kB). View file
|
|