|
import os |
|
import onnx |
|
import shutil |
|
import zipfile |
|
import argparse |
|
|
|
|
|
# Triton ``data_type`` names keyed by ONNX ``TensorProto.DataType`` codes
# (1=FLOAT, 2=UINT8, 3=INT8, 4=UINT16, 5=INT16, 6=INT32, 7=INT64, 8=STRING,
# 9=BOOL, 10=FLOAT16, 11=DOUBLE, 12=UINT32, 13=UINT64, 16=BFLOAT16).
# Plain ints keep this table usable without touching ``onnx`` at import time.
_ONNX_TO_TRITON_DTYPE = {
    1: "TYPE_FP32",
    2: "TYPE_UINT8",
    3: "TYPE_INT8",
    4: "TYPE_UINT16",
    5: "TYPE_INT16",
    6: "TYPE_INT32",
    7: "TYPE_INT64",
    8: "TYPE_STRING",
    9: "TYPE_BOOL",
    10: "TYPE_FP16",
    11: "TYPE_FP64",
    12: "TYPE_UINT32",
    13: "TYPE_UINT64",
    16: "TYPE_BF16",
}

# Fallback for element types with no Triton equivalent, or unset (elem_type 0,
# e.g. for non-tensor values such as sequences/maps).
_DEFAULT_TRITON_DTYPE = "TYPE_FP32"


def _tensor_entry(value_info, strip_batch_dim):
    """Build the config entry (name, data_type, dims) for one graph input/output."""
    tensor_type = value_info.type.tensor_type
    # Unknown or symbolic dimensions have dim_value 0; Triton spells them -1.
    dims = [dim.dim_value if dim.dim_value >= 1 else -1 for dim in tensor_type.shape.dim]
    if strip_batch_dim:
        # With max_batch_size > 0 Triton prepends the batch dimension itself.
        dims = dims[1:]
    return {
        "name": value_info.name,
        "data_type": _ONNX_TO_TRITON_DTYPE.get(tensor_type.elem_type, _DEFAULT_TRITON_DTYPE),
        "dims": dims,
    }


def _pbtxt_string(value):
    """Quote *value* as a protobuf text-format string literal."""
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def _pbtxt_message_list(field, messages):
    """Render a repeated message field; each message is a list of "key: value" lines."""
    blocks = [
        "  {\n" + "".join(f"    {line}\n" for line in message) + "  }"
        for message in messages
    ]
    # Text format requires commas between the elements of a ``[...]`` list.
    return f"{field} [\n" + ",\n".join(blocks) + "\n]"


def _render_triton_config(config):
    """Serialize the dict built by create_triton_config to config.pbtxt text."""

    def tensor_lines(tensor):
        dims = ", ".join(str(dim) for dim in tensor["dims"])
        return [
            f"name: {_pbtxt_string(tensor['name'])}",
            f"data_type: {tensor['data_type']}",
            f"dims: [ {dims} ]",
        ]

    sections = [
        f"name: {_pbtxt_string(config['name'])}",
        f"backend: {_pbtxt_string(config['backend'])}",
        f"max_batch_size: {config['max_batch_size']}",
        _pbtxt_message_list("input", [tensor_lines(t) for t in config["input"]]),
        _pbtxt_message_list("output", [tensor_lines(t) for t in config["output"]]),
        _pbtxt_message_list(
            "instance_group",
            [[f"count: {g['count']}", f"kind: {g['kind']}"] for g in config["instance_group"]],
        ),
    ]
    return "\n".join(sections) + "\n"


def create_triton_config(model_path, config_path, model_name, max_batch_size=0):
    """Write a Triton ``config.pbtxt`` describing an ONNX model.

    Loads the model at *model_path* and emits one ``input`` entry per graph
    input (initializers excluded) and one ``output`` entry per graph output.
    The config uses the ``onnxruntime`` backend with a single CPU instance.
    Tensor data types are taken from the model; types without a Triton
    equivalent fall back to ``TYPE_FP32``.

    Args:
        model_path: Path of the ``.onnx`` file to inspect.
        config_path: Destination path of the generated ``config.pbtxt``.
        model_name: Value written to the config's ``name`` field.
        max_batch_size: Triton ``max_batch_size``. When greater than 0, the
            leading (batch) dimension of every tensor is left out of
            ``dims`` because Triton adds it implicitly. When 0, the full
            tensor shape is written.
    """
    graph = onnx.load(model_path).graph
    strip_batch_dim = max_batch_size > 0
    # ONNX IR versions < 4 also list initializers (weights) in graph.input;
    # they are baked into the model and are not request inputs.
    initializer_names = {init.name for init in graph.initializer}

    config = {
        "name": model_name,
        "backend": "onnxruntime",
        "max_batch_size": max_batch_size,
        "input": [
            _tensor_entry(value_info, strip_batch_dim)
            for value_info in graph.input
            if value_info.name not in initializer_names
        ],
        "output": [_tensor_entry(value_info, strip_batch_dim) for value_info in graph.output],
        "instance_group": [{"count": 1, "kind": "KIND_CPU"}],
    }

    with open(config_path, "w", encoding="utf-8") as f:
        f.write(_render_triton_config(config))
    print(f"The configuration file has been saved to '{config_path}'")
|
|
|
|
|
def list_onnx_files(directory, *, recursive=True):
    """Return the paths of ``.onnx`` files found under *directory*.

    Paths are relative to *directory*. Files directly inside it come back as
    bare file names (``"model.onnx"``). Files in subdirectories keep their
    sub-path (``"sub/model.onnx"``), so they can still be opened relative to
    *directory*. The result is sorted for a deterministic order.

    Args:
        directory: Folder to search. A missing folder yields an empty list.
        recursive: When True (default), descend into subdirectories; when
            False, only list files directly inside *directory*.

    Returns:
        Sorted list of relative paths ending in ``.onnx``.
    """
    onnx_files = []
    for root, subdirs, files in os.walk(directory):
        onnx_files.extend(
            os.path.relpath(os.path.join(root, name), directory)
            for name in files
            if name.endswith(".onnx")
        )
        if not recursive:
            # os.walk is top-down, so emptying subdirs in place stops descent.
            subdirs.clear()
    return sorted(onnx_files)
|
|
|
|
|
def zip_folder(folder_path, output_path):
    """Write the contents of *folder_path* to a deflate-compressed ZIP archive.

    Archive entries are named relative to *folder_path*, so the folder itself
    is not a top-level entry. Every subdirectory, including empty ones, gets
    its own directory entry. Entries are added in sorted order so repeated
    runs produce the same layout.

    Args:
        folder_path: Directory whose contents are archived.
        output_path: Path of the ZIP file to create (overwritten if present).
            It may lie inside *folder_path*; the archive never includes itself.

    Raises:
        NotADirectoryError: if *folder_path* is not an existing directory,
            rather than silently writing an empty archive.
    """
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Cannot zip '{folder_path}': not a directory")

    archive_realpath = os.path.realpath(output_path)
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for root, subdirs, files in os.walk(folder_path):
            subdirs.sort()  # in place: also makes os.walk's traversal order stable
            for name in subdirs:
                dir_path = os.path.join(root, name)
                archive.write(dir_path, arcname=os.path.relpath(dir_path, folder_path))
            for name in sorted(files):
                file_path = os.path.join(root, name)
                # The archive file is created before the walk, so it can appear here.
                if os.path.realpath(file_path) == archive_realpath:
                    continue
                archive.write(file_path, arcname=os.path.relpath(file_path, folder_path))
    print(f"The folder '{folder_path}' has been zipped to '{output_path}'")
|
|
|
|
|
""" |
|
python model_packaging.py adaboost_regressor |
|
""" |
|
def _package_model(folder_path, filename, foldername, version):
    """Lay out one model as a Triton model directory and zip it.

    Creates ``<foldername>/config.pbtxt`` and ``<foldername>/<version>/model.onnx``
    from *filename*, then zips ``<foldername>`` to ``<foldername>.zip``.
    Relative paths are resolved against the current working directory, which
    main() sets to *folder_path* beforehand.
    """
    folderdir = os.path.join(foldername, version)
    os.makedirs(folderdir, exist_ok=True)
    shutil.copy(filename, folderdir)

    model_path = os.path.join(folderdir, filename)
    config_path = os.path.join(foldername, "config.pbtxt")
    create_triton_config(model_path, config_path, foldername, max_batch_size=0)

    # Triton's ONNX Runtime backend looks for <version>/model.onnx by default.
    renamed_path = os.path.join(folderdir, 'model.onnx')
    os.rename(model_path, renamed_path)
    print(f"The file '{model_path}' has been renamed to '{renamed_path}'")
    print(f"{foldername} folder created successfully!")
    print(f"{filename} copied to {foldername} successfully!")

    zip_folder(os.path.join(folder_path, foldername), f"{foldername}.zip")


def main(argv=None):
    """Package every top-level ONNX model in a directory for Triton.

    Usage::

        python model_packaging.py adaboost_regressor

    Each ``<name>.onnx`` directly inside the given directory becomes
    ``<name>/config.pbtxt`` plus ``<name>/1/model.onnx`` and is zipped to
    ``<name>.zip`` in the same directory. Hidden files are ignored. Models
    whose ``<name>`` folder already exists are skipped.

    Args:
        argv: Command-line arguments without the program name; when None,
            argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Process a directory with ONNX models.')
    parser.add_argument('folder_path', type=str, help='Path to the directory with ONNX models.')
    args = parser.parse_args(argv)

    # Resolve before chdir: a relative path would otherwise be re-interpreted
    # relative to the new working directory, match no files, and package nothing.
    folder_path = os.path.abspath(args.folder_path)
    os.chdir(folder_path)
    version = '1'

    # Keep only model files that sit directly in folder_path (not, e.g., the
    # <name>/1/model.onnx copies left by a previous run); dict.fromkeys drops
    # duplicate names while preserving order.
    filenames = [
        name
        for name in dict.fromkeys(list_onnx_files(folder_path))
        if not os.path.dirname(name) and os.path.isfile(name)
    ]
    print(filenames)

    for filename in filenames:
        if filename.startswith("."):
            continue  # hidden files are ignored
        foldername = os.path.splitext(filename)[0]
        if os.path.exists(foldername):
            print(f"{foldername} folder already exists.")
            continue
        _package_model(folder_path, filename, foldername, version)


if __name__ == "__main__":
    main()
|
|