|
import onnx_graphsurgeon as gs |
|
import onnx |
|
import numpy as np |
|
|
|
|
|
INPUT_MODEL = "check3_fuse_ops.onnx"
OUTPUT_MODEL = "sam_vit_b_01ec64.pth.encoder.patched.onnx"

# (dim0, dim1, dim2) prefixes of 5-D Reshape targets to collapse to 4-D.
# A matching target [d0, d1, d2, a, b] is rewritten to [d0, d1 * d2, a, b].
# NOTE(review): the meaning of these values (e.g. heads x images/windows,
# spatial grid sizes) is not established by this script — confirm upstream.
MERGE_PREFIXES = ((12, 64, 64), (300, 14, 14))


def merged_reshape_shape(shape, prefixes=MERGE_PREFIXES):
    """Return the 4-D Reshape target obtained by merging dims 1 and 2.

    Args:
        shape: 1-D sequence/array holding a Reshape target shape.
        prefixes: iterable of (dim0, dim1, dim2) triples that select which
            5-D targets are eligible for merging.

    Returns:
        An int64 numpy array ``[d0, d1 * d2, d3, d4]`` when *shape* is a
        5-element 1-D target whose first three entries match one of
        *prefixes*; otherwise ``None`` (the shape should be left as-is).
    """
    shape = np.asarray(shape)
    if shape.ndim != 1 or shape.shape[0] != 5:
        return None
    d0, d1, d2, d3, d4 = (int(v) for v in shape)
    wanted = {tuple(int(v) for v in prefix) for prefix in prefixes}
    if (d0, d1, d2) not in wanted:
        return None
    # ONNX Reshape interprets 0 as "copy the input dimension at this index".
    # Shifting dims 3/4 to positions 2/3 would change which input dim is
    # copied, so such targets cannot be rewritten safely.
    if 0 in (d3, d4):
        return None
    return np.array([d0, d1 * d2, d3, d4], dtype=np.int64)


def patch_reshape_nodes(graph, prefixes=MERGE_PREFIXES):
    """Rewrite eligible constant Reshape targets in *graph* in place.

    Each Reshape node whose shape input is a ``gs.Constant`` accepted by
    ``merged_reshape_shape`` has that constant's values replaced by the
    merged 4-D shape. A line is printed for every rewrite.

    Returns:
        The number of rewrites performed.
    """
    count = 0
    for node in graph.nodes:
        # Reshape takes its target as the second input; skip anything else.
        if node.op != "Reshape" or len(node.inputs) < 2:
            continue
        shape_input = node.inputs[1]
        # Only compile-time constant targets can be patched here.
        if not isinstance(shape_input, gs.Constant):
            continue
        current_shape = shape_input.values
        new_shape = merged_reshape_shape(current_shape, prefixes)
        if new_shape is None:
            continue
        print(f"Patched {current_shape} -> {new_shape}")
        shape_input.values = new_shape
        count += 1
    return count


def main(input_path=INPUT_MODEL, output_path=OUTPUT_MODEL):
    """Load *input_path*, patch its Reshape targets, save to *output_path*."""
    graph = gs.import_onnx(onnx.load(input_path))
    count = patch_reshape_nodes(graph)
    graph.cleanup().toposort()
    print(f"Patched {count} nodes.")

    model = gs.export_onnx(graph)
    # Intermediate value_info shapes were recorded for the old 5-D targets;
    # clear them so the saved model carries no stale, contradictory shapes.
    for value_info in model.graph.value_info:
        value_info.type.tensor_type.ClearField("shape")

    onnx.save(model, output_path)
    print(f"Saved as '{output_path}'")


if __name__ == "__main__":
    main()