mwitiderrick
commited on
Commit
·
d9b2258
1
Parent(s):
08d7494
Create onnx_kv_inject.py
Browse files- onnx_kv_inject.py +17 -0
onnx_kv_inject.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
import os
|
3 |
+
import onnx
|
4 |
+
from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
|
5 |
+
from sparseml.onnx.utils import ONNXGraph
|
6 |
+
@click.command()
|
7 |
+
@click.option('--input-file', help='Path to the input ONNX model file')
|
8 |
+
@click.option('--output-file', help='Output path for the modified model')
|
9 |
+
def modify_model(input_file, output_file):
|
10 |
+
model = onnx.load(input_file, load_external_data=False)
|
11 |
+
model = KeyValueCacheInjector(model_path=os.path.dirname(input_file)).apply(model)
|
12 |
+
graph = ONNXGraph(model)
|
13 |
+
graph.delete_orphaned_node_branches()
|
14 |
+
onnx.save(model, output_file)
|
15 |
+
print(f"Modified model saved to: {output_file}")
|
16 |
+
if __name__ == '__main__':
|
17 |
+
modify_model()
|