|
import onnx |
|
import os |
|
import itertools |
|
import argparse |
|
from onnxconverter_common.float16 import convert_float_to_float16 |
|
from onnxruntime.quantization import quantize_dynamic, QuantType |
|
from multiprocessing import Pool |
|
from tabulate import tabulate |
|
|
|
def float16(input, output):
    """Save a float16 variant of an ONNX model.

    The model stored at ``input`` is loaded, passed through
    ``convert_float_to_float16`` and the result is written to ``output``.
    """
    half_precision = convert_float_to_float16(onnx.load(input))
    onnx.save(half_precision, output)
|
|
|
def qint8(input, output):
    """Dynamically quantize the model at ``input`` using ``QuantType.QInt8`` weights.

    The quantized model is written to ``output``.
    """
    weight_type = QuantType.QInt8
    quantize_dynamic(input, output, weight_type=weight_type)
|
|
|
def quint8(input, output):
    """Dynamically quantize the model at ``input`` using ``QuantType.QUInt8`` weights.

    The quantized model is written to ``output``.
    """
    weight_type = QuantType.QUInt8
    quantize_dynamic(input, output, weight_type=weight_type)
|
|
|
def print_table(table):
    """Print ``table`` to stdout as a GitHub-style table.

    ``table`` is handed to ``tabulate`` with ``headers="keys"``, so its keys
    become the column headers.
    """
    rendered = tabulate(table, headers="keys", tablefmt="github")
    # The extra "\n" argument leaves a blank line after the table.
    print(rendered, "\n")
|
|
|
def get_file_mb(path):
    """Return the size of the file at ``path`` in megabytes, as a string.

    A megabyte is counted as 1,000,000 bytes and the value is rounded to a
    whole number. ``"N/A"`` is returned when the file does not exist.
    """
    try:
        size_bytes = os.stat(path).st_size
    except FileNotFoundError:
        return "N/A"
    return str(round(size_bytes / 1_000_000))
|
|
|
def convert(name, mode, f, markdown):
    """Create one variant of a converted model unless it already exists.

    Args:
        name: Model ID used to build the file names, e.g. ``"resnet-50"``.
        mode: Mode used to build the file names, e.g. ``"visual"``.
        f: Conversion callable, invoked as ``f(input_path, output_path)``.
            Its ``__name__`` becomes the suffix of the output file name.
        markdown: When true, progress messages are suppressed.

    Returns:
        ``[input_path, output_path, name, mode, variant, available]``.
        ``available`` is "✅" if the output file already existed when this
        function was called and "❌" if it did not, even when the conversion
        that was then run succeeded.
    """
    variant = f.__name__
    src_path = f"models/clip-{name}-{mode}.onnx"
    dst_path = f"models/clip-{name}-{mode}-{variant}.onnx"

    def report(status):
        # Progress output is only wanted outside of markdown mode.
        if not markdown:
            print(f"{dst_path} {status}")

    already_present = os.path.exists(dst_path)
    if already_present:
        report("exists")
    else:
        report("converting")
        f(src_path, dst_path)
        report("done")

    # The markers must be the real emoji: they end up in the "Available"
    # column of the generated markdown table.
    available = "✅" if already_present else "❌"
    return [src_path, dst_path, name, mode, variant, available]
|
|
|
if __name__ == '__main__':
    # Script entry point: build every (model, mode, data type) variant and,
    # with --markdown, print tables describing them.
    parser = argparse.ArgumentParser(description='Create variants of converted models')
    parser.add_argument(
        '--markdown',
        action='store_true',
        help='Print markdown tables describing the variants'
    )
    args = parser.parse_args()
    markdown = args.markdown

    # Every model ID must appear exactly once. A repeated ID schedules the
    # same conversion several times, which lets pool workers write the same
    # output file concurrently and duplicates rows in the markdown tables.
    names = [
        "resnet-50",
        "resnet-101",
        "resnet-50x4",
        "resnet-50x16",
        "resnet-50x64",
        "vit-base-patch16",
        "vit-base-patch32",
        "vit-large-patch14",
        "vit-large-patch14-336",
    ]
    modes = [
        "visual",
        "textual",
    ]
    funcs = [
        float16,
        qint8,
        quint8,
    ]

    if markdown:
        print_table({"Model ID": names})
        print_table({"Mode": modes})
        print_table({"Data Type": [f.__name__ for f in funcs]})

    # One task per (name, mode, func); the markdown flag is passed through to
    # convert() unchanged for every task.
    variants = itertools.product(names, modes, funcs, [markdown])

    # A single worker is used in markdown mode, eight otherwise.
    with Pool(1 if markdown else 8) as pool:
        variants_table = pool.starmap(convert, variants)

    if markdown:
        rows = [["Path", "Model ID", "Mode", "Data Type", "Available", "Size (MB)"]]
        prev_src = ""
        for src_path, dst_path, name, mode, variant, available in variants_table:
            # starmap preserves task order, so rows sharing a source model are
            # adjacent; emit the float32 original once, ahead of its variants.
            if src_path != prev_src:
                prev_src = src_path
                rows.append(
                    [src_path, name, mode, "float32 (original)", "✅", get_file_mb(src_path)]
                )
            rows.append([dst_path, name, mode, variant, available, get_file_mb(dst_path)])

        print(tabulate(rows, headers="firstrow", tablefmt="github"))
    else:
        print("done")
|
|
|
|