File size: 3,406 Bytes
6a3ad5b
 
 
 
b98d24d
6a3ad5b
 
b98d24d
6a3ad5b
 
 
b98d24d
 
 
6a3ad5b
 
 
 
 
 
 
 
 
 
 
b98d24d
 
 
 
6a3ad5b
 
 
 
 
 
 
 
 
 
 
 
 
b98d24d
6a3ad5b
 
b98d24d
 
6a3ad5b
b98d24d
6a3ad5b
b98d24d
 
6a3ad5b
b98d24d
 
 
 
 
 
 
 
6a3ad5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b98d24d
6a3ad5b
 
 
 
 
 
 
 
 
 
 
 
 
 
b98d24d
6a3ad5b
b98d24d
 
 
6a3ad5b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import onnx
import os
import itertools
import argparse
import shutil
from onnxconverter_common.float16 import convert_float_to_float16
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from multiprocessing import Pool
from tabulate import tabulate

def float32(input, output):
    """Produce the float32 variant by copying the source model verbatim.

    The converted models are already float32, so no transformation is
    needed; copy2 preserves file metadata as well.
    """
    shutil.copy2(input, output)

def float16(input, output):
    """Write a float16 variant: load the ONNX model, downcast float32
    initializers/tensors to float16, and save the result."""
    onnx.save(convert_float_to_float16(onnx.load(input)), output)

def qint8(input, output):
    """Dynamically quantize the model weights to signed 8-bit integers."""
    quantize_dynamic(input, output, weight_type=QuantType.QInt8)

def quint8(input, output):
    """Dynamically quantize the model weights to unsigned 8-bit integers."""
    quantize_dynamic(input, output, weight_type=QuantType.QUInt8)

def infer_shapes(input, output):
    """Run symbolic shape inference on the model at *input* and save to *output*.

    Needed so TensorRT has concrete shape information available.
    """
    model = onnx.load(input)
    inferred = SymbolicShapeInference.infer_shapes(model)
    onnx.save(inferred, output)

def print_table(table):
    """Print *table* as a GitHub-flavored markdown table, followed by a blank line."""
    rendered = tabulate(table, headers="keys", tablefmt="github")
    print(rendered, "\n")

def get_file_mb(path):
    """Return the size of *path* in whole decimal megabytes as a string.

    Returns the literal string "N/A" when the file does not exist, so the
    value can be dropped straight into a markdown table cell.
    """
    try:
        size_bytes = os.stat(path).st_size
    except FileNotFoundError:
        return "N/A"
    return f"{round(size_bytes / 1_000_000)}"

def convert(name, mode, f, markdown):
    """Create (or describe) one model variant.

    name      -- model identifier, e.g. "resnet-50"
    mode      -- "visual" or "textual"
    f         -- conversion function (float32/float16/qint8/quint8); its
                 __name__ labels the output file
    markdown  -- when True, do no work and return a table row instead

    Returns a list row [path, name, mode, dtype, availability] in markdown
    mode; otherwise returns None after converting (skipping existing files).
    """
    dtype = f.__name__
    src = f"converted/clip-{name}-{mode}.onnx"
    dst = f"models/clip-{name}-{mode}-{dtype}.onnx"
    exists = os.path.exists(dst)
    if markdown:
        return [dst, name, mode, dtype, "✅" if exists else "❌"]
    if exists:
        print(f"{dst} exists")
        return
    if mode == "textual":
        # Textual models additionally need shape inference for TensorRT,
        # so convert into a temp file first and infer into the final path.
        tmp = f"{dst}.temp"
        print(f"{dst} converting")
        f(src, tmp)
        print(f"{dst} running shape inference for TensorRT support")
        infer_shapes(tmp, dst)
        os.remove(tmp)
    else:
        print(f"{dst} converting")
        f(src, dst)
    print(f"{dst} done")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create variants of converted models')
    parser.add_argument(
        '--markdown',
        action='store_true',
        help='Print markdown tables describing the variants'
    )
    args = parser.parse_args()
    # Model identifiers; each expects a source file at
    # converted/clip-<name>-<mode>.onnx.
    # FIX: the original list contained "resnet-50" four times. The duplicate
    # entries scheduled identical jobs into the Pool, where concurrent
    # workers could race writing the same output file, and produced
    # duplicate rows in the markdown table. Duplicates removed.
    names = [
        "resnet-50",
        "resnet-101",
        "resnet-50x4",
        "resnet-50x16",
        "resnet-50x64",
        "vit-base-patch16",
        "vit-base-patch32",
        "vit-large-patch14",
        "vit-large-patch14-336",
    ]
    modes = [
        "visual",
        "textual"
    ]
    funcs = [
        float32,
        float16,
        qint8,
        quint8,
    ]
    markdown = args.markdown
    if markdown:
        print_table({ "Model ID": names })
        print_table({ "Mode": modes })
        print_table({ "Data Type": [f.__name__ for f in funcs] })
    # One task per (model, mode, dtype) combination.
    variants = itertools.product(names, modes, funcs, [markdown])
    # Single worker in markdown mode keeps row order deterministic;
    # real conversions fan out over 8 processes.
    with Pool(8 if not markdown else 1) as p:
        variants_table = p.starmap(convert, variants)
        if markdown:
            for row in variants_table:
                output = row[0]
                file_size = get_file_mb(output)
                row.append(file_size)
            variants_table.insert(0, ["Path", "Model ID", "Mode", "Data Type", "Available", "Size (MB)"])
            print(tabulate(variants_table, headers="firstrow", tablefmt="github"))
        else:
            print("done")