# clip-variants / variants.py
# Last commit b98d24d by mlunar: "Add shape inference for textual models"
import onnx
import os
import itertools
import argparse
import shutil
from onnxconverter_common.float16 import convert_float_to_float16
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from multiprocessing import Pool
from tabulate import tabulate
def float32(input, output):
    """Produce the float32 variant by copying the converted model verbatim.

    The source export is used unchanged; ``shutil.copy2`` also carries over
    file metadata such as timestamps.
    """
    shutil.copy2(src=input, dst=output)
def float16(input, output):
    """Write a float16 copy of the ONNX model at *input* to *output*.

    Uses ``convert_float_to_float16`` from onnxconverter_common with its
    default options.
    """
    onnx.save(convert_float_to_float16(onnx.load(input)), output)
def qint8(input, output):
    """Dynamically quantize the model at *input*, storing weights as signed int8."""
    quantize_dynamic(
        input,
        output,
        weight_type=QuantType.QInt8,
    )
def quint8(input, output):
    """Dynamically quantize the model at *input*, storing weights as unsigned int8."""
    quantize_dynamic(
        input,
        output,
        weight_type=QuantType.QUInt8,
    )
def infer_shapes(input, output):
    """Run ONNX Runtime's symbolic shape inference on a model file.

    Loads the model at *input*, annotates it with inferred tensor shapes and
    saves the result to *output*. Because the model is fully loaded into
    memory before saving, *input* and *output* may be the same path.

    Raises:
        RuntimeError: if shape inference returns no model.
    """
    model = onnx.load(input)
    out_mp = SymbolicShapeInference.infer_shapes(model)
    if out_mp is None:
        # SymbolicShapeInference.infer_shapes can return None instead of
        # raising (onnxruntime does so e.g. for models below opset 7);
        # onnx.save(None, ...) would otherwise fail with an obscure error.
        raise RuntimeError(f"symbolic shape inference returned no model for {input}")
    onnx.save(out_mp, output)
def print_table(table):
    """Print *table* as a GitHub-flavoured markdown table, followed by a blank line.

    *table* is passed to ``tabulate`` with ``headers="keys"``, so its keys
    become the column headers.
    """
    rendered = tabulate(table, headers="keys", tablefmt="github")
    print(rendered, "\n")
def get_file_mb(path):
    """Return the size of *path* in whole megabytes (10**6 bytes) as a string.

    The size is rounded with Python's ``round``; a missing file yields
    ``"N/A"`` instead of raising.
    """
    try:
        size_bytes = os.path.getsize(path)
    except FileNotFoundError:
        return "N/A"
    return str(round(size_bytes / 1_000_000))
def convert(name, mode, f, markdown):
    """Create one model variant, or describe it for the markdown report.

    The source model is read from ``converted/clip-{name}-{mode}.onnx`` and
    the variant is written to ``models/clip-{name}-{mode}-{f.__name__}.onnx``.
    Existing outputs are never rebuilt.

    Args:
        name: Model ID, e.g. ``"resnet-50"``.
        mode: ``"visual"`` or ``"textual"``; textual variants additionally get
            symbolic shape inference.
        f: Conversion function called as ``f(input_path, output_path)``; its
            ``__name__`` becomes the data-type suffix of the output file.
        markdown: If true, convert nothing and just return a table row.

    Returns:
        In markdown mode, ``[output, name, mode, fname, availability]`` where
        availability is "✅" if the output file exists and "❌" otherwise.
        Otherwise ``None``.
    """
    fname = f.__name__
    input_path = f"converted/clip-{name}-{mode}.onnx"
    output = f"models/clip-{name}-{mode}-{fname}.onnx"
    exists = os.path.exists(output)
    if markdown:
        return [output, name, mode, fname, "✅" if exists else "❌"]
    if exists:
        print(f"{output} exists")
        return None

    # The conversion functions do not create missing parent directories.
    os.makedirs(os.path.dirname(output), exist_ok=True)

    # Build into a temp file and move it into place only once complete:
    # writing straight to `output` lets an interrupted run leave a truncated
    # model behind, which the `exists` check above would then skip forever.
    output_temp = f"{output}.temp"
    try:
        print(f"{output} converting")
        f(input_path, output_temp)
        if mode == "textual":
            print(f"{output} running shape inference for TensorRT support")
            # In place is fine: the model is loaded fully before re-saving.
            infer_shapes(output_temp, output_temp)
        os.replace(output_temp, output)
    finally:
        # On success os.replace already consumed the temp file; on failure
        # remove the partial one so it does not leak.
        if os.path.exists(output_temp):
            os.remove(output_temp)
    print(f"{output} done")
    return None
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create variants of converted models')
    parser.add_argument(
        '--markdown',
        action='store_true',
        help='Print markdown tables describing the variants'
    )
    parser.add_argument(
        '--jobs',
        type=int,
        default=8,
        help='Number of conversions to run in parallel (default: 8)'
    )
    args = parser.parse_args()

    # Each model ID must appear exactly once: duplicates produced repeated
    # table rows and made parallel workers build the same output file
    # concurrently.
    names = [
        "resnet-50",
        "resnet-101",
        "resnet-50x4",
        "resnet-50x16",
        "resnet-50x64",
        "vit-base-patch16",
        "vit-base-patch32",
        "vit-large-patch14",
        "vit-large-patch14-336",
    ]
    modes = [
        "visual",
        "textual",
    ]
    funcs = [
        float32,
        float16,
        qint8,
        quint8,
    ]

    markdown = args.markdown
    if markdown:
        print_table({"Model ID": names})
        print_table({"Mode": modes})
        print_table({"Data Type": [f.__name__ for f in funcs]})

    variants = itertools.product(names, modes, funcs, [markdown])
    if markdown:
        # Markdown mode only checks which outputs exist; no worker pool needed.
        variants_table = list(itertools.starmap(convert, variants))
    else:
        with Pool(args.jobs) as p:
            variants_table = p.starmap(convert, variants)

    if markdown:
        for row in variants_table:
            row.append(get_file_mb(row[0]))
        variants_table.insert(0, ["Path", "Model ID", "Mode", "Data Type", "Available", "Size (MB)"])
        print(tabulate(variants_table, headers="firstrow", tablefmt="github"))
    else:
        print("done")