smpanaro commited on
Commit
07183f4
·
1 Parent(s): dba673f

Add new joint prefill+generation cache processor

Browse files
cache-processor.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12adcd9eb610f08f550a8914fe8c71a5748b2ad96ca8e099ffe4ad2a40ded079
3
+ size 243
cache-processor.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9b26e05e4ae0f96f84c27589c4f0466340377ba6827456f8fcc6414539b5718
3
+ size 516
cache-processor.mlmodelc/metadata.json ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "metadataOutputVersion" : "3.0",
4
+ "outputSchema" : [
5
+ {
6
+ "hasShapeFlexibility" : "0",
7
+ "isOptional" : "0",
8
+ "dataType" : "Float16",
9
+ "formattedType" : "MultiArray (Float16 1 × 448 × 1 × 4096)",
10
+ "shortDescription" : "",
11
+ "shape" : "[1, 448, 1, 4096]",
12
+ "name" : "updated_k_cache",
13
+ "type" : "MultiArray"
14
+ },
15
+ {
16
+ "hasShapeFlexibility" : "0",
17
+ "isOptional" : "0",
18
+ "dataType" : "Float16",
19
+ "formattedType" : "MultiArray (Float16 1 × 4096 × 1 × 448)",
20
+ "shortDescription" : "",
21
+ "shape" : "[1, 4096, 1, 448]",
22
+ "name" : "updated_v_cache",
23
+ "type" : "MultiArray"
24
+ },
25
+ {
26
+ "hasShapeFlexibility" : "0",
27
+ "isOptional" : "0",
28
+ "dataType" : "Float16",
29
+ "formattedType" : "MultiArray (Float16)",
30
+ "shortDescription" : "",
31
+ "shape" : "[]",
32
+ "name" : "ignore_me_im_only_here_so_this_runs_on_the_ane",
33
+ "type" : "MultiArray"
34
+ }
35
+ ],
36
+ "modelParameters" : [
37
+
38
+ ],
39
+ "specificationVersion" : 7,
40
+ "mlProgramOperationTypeHistogram" : {
41
+ "SliceByIndex" : 2,
42
+ "Ios16.mul" : 1,
43
+ "Concat" : 2,
44
+ "Ios16.reduceMin" : 1
45
+ },
46
+ "computePrecision" : "Mixed (Float16, Int32)",
47
+ "isUpdatable" : "0",
48
+ "availability" : {
49
+ "macOS" : "13.0",
50
+ "tvOS" : "16.0",
51
+ "visionOS" : "1.0",
52
+ "watchOS" : "9.0",
53
+ "iOS" : "16.0",
54
+ "macCatalyst" : "16.0"
55
+ },
56
+ "modelType" : {
57
+ "name" : "MLModelType_mlProgram"
58
+ },
59
+ "userDefinedMetadata" : {
60
+ "com.github.apple.coremltools.source_dialect" : "TorchScript",
61
+ "com.github.apple.coremltools.source" : "torch==2.1.0",
62
+ "com.github.apple.coremltools.version" : "8.0b1"
63
+ },
64
+ "inputSchema" : [
65
+ {
66
+ "hasShapeFlexibility" : "0",
67
+ "isOptional" : "0",
68
+ "dataType" : "Float16",
69
+ "formattedType" : "MultiArray (Float16 1 × 448 × 1 × 4096)",
70
+ "shortDescription" : "",
71
+ "shape" : "[1, 448, 1, 4096]",
72
+ "name" : "old_k_cache",
73
+ "type" : "MultiArray"
74
+ },
75
+ {
76
+ "hasShapeFlexibility" : "0",
77
+ "isOptional" : "0",
78
+ "dataType" : "Float16",
79
+ "formattedType" : "MultiArray (Float16 1 × 64 × 1 × 4096)",
80
+ "shortDescription" : "",
81
+ "shape" : "[1, 64, 1, 4096]",
82
+ "name" : "new_k_cache",
83
+ "type" : "MultiArray"
84
+ },
85
+ {
86
+ "hasShapeFlexibility" : "0",
87
+ "isOptional" : "0",
88
+ "dataType" : "Float16",
89
+ "formattedType" : "MultiArray (Float16 1 × 4096 × 1 × 448)",
90
+ "shortDescription" : "",
91
+ "shape" : "[1, 4096, 1, 448]",
92
+ "name" : "old_v_cache",
93
+ "type" : "MultiArray"
94
+ },
95
+ {
96
+ "hasShapeFlexibility" : "0",
97
+ "isOptional" : "0",
98
+ "dataType" : "Float16",
99
+ "formattedType" : "MultiArray (Float16 1 × 4096 × 1 × 64)",
100
+ "shortDescription" : "",
101
+ "shape" : "[1, 4096, 1, 64]",
102
+ "name" : "new_v_cache",
103
+ "type" : "MultiArray"
104
+ }
105
+ ],
106
+ "generatedClassName" : "cache_processor",
107
+ "method" : "predict"
108
+ }
109
+ ]
cache-processor.mlmodelc/model.mil ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ program(1.0)
2
+ [buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
3
+ {
4
+ func main<ios16>(tensor<fp16, [1, 64, 1, 4096]> new_k_cache, tensor<fp16, [1, 4096, 1, 64]> new_v_cache, tensor<fp16, [1, 448, 1, 4096]> old_k_cache, tensor<fp16, [1, 4096, 1, 448]> old_v_cache) {
5
+ tensor<int32, []> var_6 = const()[name = tensor<string, []>("op_6"), val = tensor<int32, []>(-3)];
6
+ tensor<bool, []> cat_k_1_interleave_0 = const()[name = tensor<string, []>("cat_k_1_interleave_0"), val = tensor<bool, []>(false)];
7
+ tensor<fp16, [1, 512, 1, 4096]> cat_k_1_cast_fp16 = concat(axis = var_6, interleave = cat_k_1_interleave_0, values = (old_k_cache, new_k_cache))[name = tensor<string, []>("cat_k_1_cast_fp16")];
8
+ tensor<int32, []> var_9 = const()[name = tensor<string, []>("op_9"), val = tensor<int32, []>(-1)];
9
+ tensor<bool, []> cat_v_interleave_0 = const()[name = tensor<string, []>("cat_v_interleave_0"), val = tensor<bool, []>(false)];
10
+ tensor<fp16, [1, 4096, 1, 512]> cat_v_cast_fp16 = concat(axis = var_9, interleave = cat_v_interleave_0, values = (old_v_cache, new_v_cache))[name = tensor<string, []>("cat_v_cast_fp16")];
11
+ tensor<int32, [4]> var_20_begin_0 = const()[name = tensor<string, []>("op_20_begin_0"), val = tensor<int32, [4]>([0, 64, 0, 0])];
12
+ tensor<int32, [4]> var_20_end_0 = const()[name = tensor<string, []>("op_20_end_0"), val = tensor<int32, [4]>([1, 512, 1, 4096])];
13
+ tensor<bool, [4]> var_20_end_mask_0 = const()[name = tensor<string, []>("op_20_end_mask_0"), val = tensor<bool, [4]>([true, false, true, true])];
14
+ tensor<fp16, [1, 448, 1, 4096]> updated_k_cache = slice_by_index(begin = var_20_begin_0, end = var_20_end_0, end_mask = var_20_end_mask_0, x = cat_k_1_cast_fp16)[name = tensor<string, []>("op_20_cast_fp16")];
15
+ tensor<int32, [4]> var_50_begin_0 = const()[name = tensor<string, []>("op_50_begin_0"), val = tensor<int32, [4]>([0, 0, 0, 64])];
16
+ tensor<int32, [4]> var_50_end_0 = const()[name = tensor<string, []>("op_50_end_0"), val = tensor<int32, [4]>([1, 4096, 1, 512])];
17
+ tensor<bool, [4]> var_50_end_mask_0 = const()[name = tensor<string, []>("op_50_end_mask_0"), val = tensor<bool, [4]>([true, true, true, false])];
18
+ tensor<fp16, [1, 4096, 1, 448]> updated_v_cache = slice_by_index(begin = var_50_begin_0, end = var_50_end_0, end_mask = var_50_end_mask_0, x = cat_v_cast_fp16)[name = tensor<string, []>("op_50_cast_fp16")];
19
+ tensor<fp16, []> var_51_promoted_to_fp16 = const()[name = tensor<string, []>("op_51_promoted_to_fp16"), val = tensor<fp16, []>(0x1p+1)];
20
+ tensor<fp16, [1, 448, 1, 4096]> prod_cast_fp16 = mul(x = updated_k_cache, y = var_51_promoted_to_fp16)[name = tensor<string, []>("prod_cast_fp16")];
21
+ tensor<bool, []> var_53_keep_dims_0 = const()[name = tensor<string, []>("op_53_keep_dims_0"), val = tensor<bool, []>(false)];
22
+ tensor<fp16, []> ignore_me_im_only_here_so_this_runs_on_the_ane = reduce_min(keep_dims = var_53_keep_dims_0, x = prod_cast_fp16)[name = tensor<string, []>("op_53_cast_fp16")];
23
+ } -> (updated_k_cache, updated_v_cache, ignore_me_im_only_here_so_this_runs_on_the_ane);
24
+ }