# models-repository / scikit-learn / convert2onnx.py
# Author: andrewssobral
# Commit: c1e1349 ("Improve scripts")
# Size: 2.55 kB
import argparse
import joblib
from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
def load_dataset(dataset_name):
    """Load one of the supported scikit-learn datasets by name.

    Args:
        dataset_name: One of ``'california'``, ``'diabetes'`` or ``'iris'``.

    Returns:
        A ``(data, target)`` tuple taken from the ``data`` and ``target``
        attributes of the object returned by the matching scikit-learn loader.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'california':
        dataset = fetch_california_housing()
    elif dataset_name == 'diabetes':
        dataset = load_diabetes()
    elif dataset_name == 'iris':
        dataset = load_iris()
    else:
        # Name the rejected value and the accepted ones so a typo is easy to
        # spot; the message keeps its original "Invalid dataset name" prefix.
        raise ValueError(
            f"Invalid dataset name: {dataset_name!r}. "
            "Choose from: 'california', 'diabetes', or 'iris'."
        )
    return dataset.data, dataset.target
def prepare_onnx_conversion_params(X, target_opset, model):
    """Build the input tensor type and converter options for skl2onnx.

    Args:
        X: Feature matrix; only ``X.shape[1]`` (the column count) is read.
        target_opset: ONNX opset version the model will be converted for.
        model: The model to convert; only used to key the options dict.

    Returns:
        A ``(tensor, options)`` tuple. ``tensor`` is a ``FloatTensorType``
        with shape ``[None, X.shape[1]]``. ``options`` is
        ``{id(model): {'zipmap': False}}`` when ``target_opset`` is 9 and
        ``None`` otherwise.
    """
    # The input width always follows the data. Opsets 9 and 17 used to be
    # pinned to 4 columns, which only matches four-feature data (e.g. iris)
    # and declares the wrong input shape for anything else.
    n_features = X.shape[1]
    tensor = FloatTensorType([None, n_features])

    # NOTE(review): 'zipmap' is a skl2onnx converter option, presumably aimed
    # at classifier outputs; why it is only disabled for opset 9 is not
    # visible from this file - confirm before changing.
    if target_opset == 9:
        options = {id(model): {'zipmap': False}}
    else:
        options = None
    return tensor, options
def convert2onnx(model, initial_type, options, target_opset, onnx_filename):
    """Export ``model`` to an ONNX file, trying two strategies in order.

    The model's own ``save_model(onnx_filename, format="onnx")`` is tried
    first. If that raises for any reason (including the model not having such
    a method), the function falls back to ``convert_sklearn`` and writes the
    serialized result to ``onnx_filename``. Errors are printed, not raised.

    Args:
        model: The trained model to export.
        initial_type: Passed to ``convert_sklearn`` as ``initial_types``.
        options: Passed to ``convert_sklearn`` as ``options``.
        target_opset: Passed to ``convert_sklearn`` as ``target_opset``.
        onnx_filename: Path of the ONNX file to write.

    Returns:
        ``True`` if either strategy succeeded, ``False`` if both failed, so
        callers can detect a failed export instead of relying on the printed
        messages alone.
    """
    # First attempt: a native ONNX export provided by the model itself.
    try:
        model.save_model(onnx_filename, format="onnx")
    except Exception as e:
        # Deliberately broad: any failure here just means "use the fallback".
        print("Error occurred while saving model in ONNX format:", e)
        print("Trying a second method...")
    else:
        print("Model saved in ONNX format successfully.")
        return True

    # Second attempt: convert through skl2onnx and write the bytes ourselves.
    # The conversion runs before the file is opened, so a failed conversion
    # does not create or truncate the output file.
    try:
        onnx_model = convert_sklearn(
            model,
            initial_types=initial_type,
            options=options,
            target_opset=target_opset,
        )
        with open(onnx_filename, "wb") as f:
            f.write(onnx_model.SerializeToString())
    except Exception as e:
        print("Error occurred while converting model to ONNX format:", e)
        return False

    print("Model converted to ONNX format and saved successfully.")
    return True
"""
python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx
"""
if __name__ == "__main__":
    # Command-line entry point: load a trained model, derive the ONNX input
    # signature from the chosen dataset, and write the converted model.
    parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
    parser.add_argument(
        'dataset_name',
        type=str,
        # Reject unknown names at parse time with a usage message instead of
        # a traceback from load_dataset.
        choices=('california', 'diabetes', 'iris'),
        help='Name of the dataset. Choose from: "california", "diabetes", or "iris".',
    )
    parser.add_argument('model_path', type=str, help='Path to the trained model file.')
    parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
    parser.add_argument(
        '--target-opset',
        type=int,
        default=12,
        help='ONNX opset version to target (default: 12).',
    )
    args = parser.parse_args()

    # Only the feature matrix is needed (for its column count); the target
    # values are not used by the conversion.
    X, _ = load_dataset(args.dataset_name)

    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only pass model files from a trusted source.
    model = joblib.load(args.model_path)

    target_opset = args.target_opset
    tensor, options = prepare_onnx_conversion_params(X, target_opset, model)

    input_name = 'float_input'
    initial_type = [(input_name, tensor)]

    convert2onnx(model, initial_type, options, target_opset, args.onnx_filename)