"""Create precision/quantization variants of converted CLIP ONNX models.

Reads source models from ``converted/clip-<name>-<mode>.onnx`` and writes
variants to ``models/clip-<name>-<mode>-<dtype>.onnx``. With ``--markdown``
no conversion is done; GitHub-flavored tables describing the variants are
printed instead.
"""

import argparse
import itertools
import os
import shutil
from multiprocessing import Pool

import onnx
from onnxconverter_common.float16 import convert_float_to_float16
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from tabulate import tabulate


def float32(input, output):
    """Keep the model in float32: just copy it (metadata preserved)."""
    shutil.copy2(input, output)


def float16(input, output):
    """Convert all float32 tensors in the model to float16."""
    model = onnx.load(input)
    model_f16 = convert_float_to_float16(model)
    onnx.save(model_f16, output)


def qint8(input, output):
    """Dynamically quantize weights to signed 8-bit integers."""
    quantize_dynamic(input, output, weight_type=QuantType.QInt8)


def quint8(input, output):
    """Dynamically quantize weights to unsigned 8-bit integers."""
    quantize_dynamic(input, output, weight_type=QuantType.QUInt8)


def infer_shapes(input, output):
    """Run symbolic shape inference over the model (needed for TensorRT)."""
    out_mp = SymbolicShapeInference.infer_shapes(onnx.load(input))
    onnx.save(out_mp, output)


def print_table(table):
    """Print *table* (mapping of column name -> values) as a GitHub table."""
    print(tabulate(table, headers="keys", tablefmt="github"), "\n")


def get_file_mb(path):
    """Return the size of *path* in whole megabytes as a string.

    Returns the string "N/A" when the file does not exist.
    """
    try:
        stat = os.stat(path)
    except FileNotFoundError:
        return "N/A"
    # NOTE(review): files smaller than ~0.5 MB round down to "0"; fine for
    # model files, which are expected to be large.
    mb = round(stat.st_size / 1_000_000)
    return str(mb)


def convert(name, mode, f, markdown):
    """Produce one model variant, or (in markdown mode) report its status.

    name:     model identifier, e.g. "vit-base-patch32".
    mode:     "visual" or "textual".
    f:        conversion function (float32 / float16 / qint8 / quint8).
    markdown: when True, perform no conversion — return a table row
              [path, name, mode, dtype, availability] instead.

    Returns the table row in markdown mode, otherwise None. Skips
    conversion when the output file already exists.
    """
    fname = f.__name__
    input = f"converted/clip-{name}-{mode}.onnx"
    output = f"models/clip-{name}-{mode}-{fname}.onnx"
    exists = os.path.exists(output)
    if markdown:
        return [output, name, mode, fname, "✅" if exists else "❌"]
    if exists:
        print(f"{output} exists")
    else:
        if mode == "textual":
            # Textual models additionally need shape inference for TensorRT:
            # convert into a temp file first, then infer shapes into the
            # final path.
            output_temp = f"{output}.temp"
            print(f"{output} converting")
            try:
                f(input, output_temp)
                print(f"{output} running shape inference for TensorRT support")
                infer_shapes(output_temp, output)
            finally:
                # Fix: don't leak the temp file when conversion or shape
                # inference raises part-way through.
                if os.path.exists(output_temp):
                    os.remove(output_temp)
        else:
            print(f"{output} converting")
            f(input, output)
    print(f"{output} done")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create variants of converted models')
    parser.add_argument(
        '--markdown',
        action='store_true',
        help='Print markdown tables describing the variants'
    )
    args = parser.parse_args()

    # NOTE(review): "resnet-50" appears four times below. Duplicate entries
    # schedule duplicate Pool tasks that may race on the same output file —
    # confirm whether the repetition is intended before deduplicating.
    names = [
        "resnet-50",
        "resnet-101",
        "resnet-50x4",
        "resnet-50x16",
        "resnet-50x64",
        "resnet-50",
        "resnet-50",
        "resnet-50",
        "vit-base-patch16",
        "vit-base-patch32",
        "vit-large-patch14",
        "vit-large-patch14-336",
    ]
    modes = [
        "visual",
        "textual",
    ]
    funcs = [
        float32,
        float16,
        qint8,
        quint8,
    ]

    markdown = args.markdown
    if markdown:
        print_table({"Model ID": names})
        print_table({"Mode": modes})
        print_table({"Data Type": [f.__name__ for f in funcs]})

    # One task per (model, mode, dtype) combination; `markdown` is a
    # constant fourth argument for every task.
    variants = itertools.product(names, modes, funcs, [markdown])
    # A single worker in markdown mode keeps the returned rows in a
    # deterministic order without spawning needless processes.
    with Pool(8 if not markdown else 1) as p:
        variants_table = p.starmap(convert, variants)

    if markdown:
        for row in variants_table:
            output = row[0]
            file_size = get_file_mb(output)
            row.append(file_size)
        variants_table.insert(
            0,
            ["Path", "Model ID", "Mode", "Data Type", "Available", "Size (MB)"],
        )
        print(tabulate(variants_table, headers="firstrow", tablefmt="github"))
    else:
        print("done")