|
|
import onnx |
|
|
from onnx import helper, TensorProto |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
def _redirect_size_input(graph, old_input_name: str, new_input_name: str) -> None:
    """Swap graph input `old_input_name` for a dynamically shaped stand-in.

    Adds a FLOAT input `new_input_name` with symbolic dims ("height", "width")
    and prepends a Shape node whose output reuses `old_input_name`, so every
    downstream consumer of the old input now receives the [height, width]
    pair derived from the dummy tensor's runtime shape — no rewiring needed.

    Raises:
        ValueError: if `old_input_name` is not an input of `graph`.
    """
    old_inputs = {vi.name: vi for vi in graph.input}
    if old_input_name not in old_inputs:
        # Explicit raise instead of `assert` so validation survives `-O`.
        raise ValueError(f"Input {old_input_name} not found")

    # Drop the original value_info and register the dynamic dummy input.
    graph.input.remove(old_inputs[old_input_name])
    new_input_vi = helper.make_tensor_value_info(
        new_input_name, TensorProto.FLOAT, ["height", "width"]
    )
    graph.input.extend([new_input_vi])

    # Shape(new_input) -> int64 [height, width]. Reusing the old input name
    # as the node's output keeps all downstream consumers untouched.
    shape_node = helper.make_node(
        "Shape",
        inputs=[new_input_name],
        outputs=[old_input_name],
        name="shape_of_orig_im_size_shape",
    )
    # Insert at position 0 so the graph stays topologically sorted.
    graph.node.insert(0, shape_node)


def _force_fp32_island(graph) -> None:
    """Force the size-math subgraph of an fp16 SAM decoder back to fp32.

    Node names are specific to the SAM ViT-B decoder export — TODO confirm
    they match if the model is re-exported with a different tool version.

    Raises:
        ValueError: if a targeted Constant node lacks a 'value' attribute.
    """
    # Recast the payloads of the constants feeding the resize math to fp32.
    fp16_constants = ["/Constant_85", "/Constant_86"]
    for node in graph.node:
        if node.op_type == "Constant" and node.name in fp16_constants:
            print(node.name)
            for attr in node.attribute:
                if attr.name == "value":
                    tensor = onnx.numpy_helper.to_array(attr.t)
                    new_tensor = tensor.astype(np.float32)
                    attr.t.CopyFrom(onnx.numpy_helper.from_array(new_tensor))
                    break
            else:
                # for/else: no 'value' attribute was found on this Constant.
                raise ValueError(f"Constant node '{node.name}' does not have a 'value' attribute.")

    # Flip the recorded value_info dtypes on the island's edges to fp32.
    fp16_nodes = ["/ReduceMax", "/Reciprocal", "/Mul_19", "/Mul_20", "/Add_11", "/Floor"]
    for node in graph.node:
        if node.name in fp16_nodes:
            print(f"Processing node: {node.name}")
            for input_name in node.input:
                for value_info in graph.value_info:
                    if value_info.name == input_name:
                        value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                        print(f" - Change input: {input_name} to fp32")
            for output_name in node.output:
                for value_info in graph.value_info:
                    if value_info.name == output_name:
                        value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                        print(f" - Change output: {output_name} to fp32")

    # Retarget the Cast feeding the island so it emits fp32.
    for node in graph.node:
        if node.name == "/Cast_9":
            # NOTE(review): assumes attribute[0] is "to" — matching the
            # attribute by name would be safer; confirm against the export.
            node.attribute[0].i = TensorProto.FLOAT
            print(f"Changed /Cast_9 to fp32")
            break


def optimize_sa_model(model_path: str, out_path: str, is_fp16: bool = False) -> None:
    """Rewrite a SAM decoder ONNX model so the output image size is dynamic.

    Replaces the fixed `orig_im_size` input with a dynamically shaped dummy
    input `orig_im_size_shape` whose Shape op supplies [height, width] at
    runtime. For fp16 exports (`is_fp16=True`), additionally forces the
    size-math subgraph back to fp32.

    Args:
        model_path: path of the ONNX model to load.
        out_path: path the patched model is saved to.
        is_fp16: set True when the model was exported in fp16.

    Raises:
        ValueError: if the expected input or node attributes are missing.
    """
    model = onnx.load(model_path)
    graph = model.graph

    _redirect_size_input(graph, "orig_im_size", "orig_im_size_shape")

    if is_fp16:
        _force_fp32_island(graph)

    # Validate the patched graph before writing it out.
    onnx.checker.check_model(model)
    onnx.save(model, out_path)
    print(f"Saved to {out_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Guard the script entry point so importing this module has no side effects.
    optimize_sa_model(
        "sam_vit_b_01ec64.decoder-fp16.onnx",
        "sam_vit_b_01ec64.decoder-orig-img-size-dynamic-fp16.onnx",
        True,
    )