|
import argparse |
|
import joblib |
|
|
|
from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris |
|
from skl2onnx import convert_sklearn |
|
from skl2onnx.common.data_types import FloatTensorType |
|
|
|
|
|
def load_dataset(dataset_name):
    """Load one of the supported scikit-learn datasets.

    Args:
        dataset_name: ``'california'``, ``'diabetes'`` or ``'iris'``.

    Returns:
        A ``(data, target)`` tuple read from the loaded dataset object.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Reject unknown names up front so no loader is touched for bad input.
    if dataset_name not in ('california', 'diabetes', 'iris'):
        raise ValueError("Invalid dataset name")

    if dataset_name == 'california':
        bunch = fetch_california_housing()
    elif dataset_name == 'diabetes':
        bunch = load_diabetes()
    else:
        bunch = load_iris()

    features, targets = bunch.data, bunch.target
    return features, targets
|
|
|
|
|
def prepare_onnx_conversion_params(X, target_opset, model):
    """Build the skl2onnx input type and converter options for *model*.

    Args:
        X: Feature matrix the model is fed; only ``X.shape[1]`` (the number
            of feature columns) is used.
        target_opset: ONNX opset version the model will be converted to.
        model: The fitted estimator; used only as the key (via ``id``) of
            the per-model options dict.

    Returns:
        ``(tensor, options)`` where ``tensor`` is a ``FloatTensorType`` whose
        first dimension is left unspecified (``None``) and whose second is
        ``X.shape[1]``, and ``options`` is ``{id(model): {'zipmap': False}}``
        for opset 9 and ``None`` otherwise.
    """
    # Always derive the input width from the data. Hard-coding it (e.g. to 4)
    # only matches 4-column datasets and yields an ONNX input signature that
    # disagrees with the model for every other dataset.
    tensor = FloatTensorType([None, X.shape[1]])

    # NOTE(review): 'zipmap' is disabled only for opset 9; the reason is not
    # visible here. Presumably it only matters for classifiers — confirm that
    # skl2onnx accepts this option for the regressors this script converts.
    options = {id(model): {'zipmap': False}} if target_opset == 9 else None

    return tensor, options
|
|
|
|
|
def convert2onnx(model, initial_type, options, target_opset, onnx_filename):
    """Write *model* to ``onnx_filename`` in ONNX format.

    Two strategies are tried in order:

    1. The model's own ``save_model(onnx_filename, format="onnx")`` method.
    2. ``convert_sklearn`` with the given input types, converter options and
       target opset, serialized to ``onnx_filename``.

    Errors from either strategy are printed rather than raised. The outcome
    is returned so callers can detect a failed export.

    Args:
        model: The fitted model to export.
        initial_type: List of ``(input_name, tensor_type)`` pairs describing
            the model inputs for ``convert_sklearn``.
        options: Converter options for ``convert_sklearn`` (may be ``None``).
        target_opset: ONNX opset version passed to ``convert_sklearn``.
        onnx_filename: Path of the ONNX file to write.

    Returns:
        ``True`` if an ONNX file was written, ``False`` if both strategies
        failed.
    """
    # Strategy 1: native exporter. Any failure -- including the model simply
    # not having a ``save_model`` method -- falls through to strategy 2.
    try:
        model.save_model(onnx_filename, format="onnx")
    except Exception as e:
        print("Error occurred while saving model in ONNX format:", e)
    else:
        print("Model saved in ONNX format successfully.")
        return True

    print("Trying a second method...")

    # Strategy 2: skl2onnx conversion, then write the serialized protobuf.
    try:
        onnx_model = convert_sklearn(
            model,
            initial_types=initial_type,
            options=options,
            target_opset=target_opset,
        )
        with open(onnx_filename, "wb") as f:
            f.write(onnx_model.SerializeToString())
    except Exception as e:
        print("Error occurred while converting model to ONNX format:", e)
        return False

    print("Model converted to ONNX format and saved successfully.")
    return True
|
|
|
""" |
|
python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx |
|
""" |
|
# Example usage:
#   python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx
#   python convert2onnx.py iris model.joblib model.onnx --target-opset 17


def main(argv=None):
    """Parse command-line arguments and convert a saved model to ONNX.

    Args:
        argv: Argument list to parse; ``None`` means the process arguments.
    """
    parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
    parser.add_argument(
        'dataset_name',
        type=str,
        # Validate here so a typo gives a usage error instead of a traceback
        # from load_dataset.
        choices=('california', 'diabetes', 'iris'),
        help='Name of the dataset. Choose from: "california", "diabetes", or "iris".',
    )
    parser.add_argument('model_path', type=str, help='Path to the trained model file.')
    parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
    parser.add_argument(
        '--target-opset',
        type=int,
        default=12,
        help='ONNX opset version to target (default: 12).',
    )
    args = parser.parse_args(argv)

    # Only the feature matrix is needed; it is handed to
    # prepare_onnx_conversion_params to size the ONNX input.
    X, _ = load_dataset(args.dataset_name)

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load model files from trusted sources.
    model = joblib.load(args.model_path)

    tensor, options = prepare_onnx_conversion_params(X, args.target_opset, model)

    initial_type = [('float_input', tensor)]
    convert2onnx(model, initial_type, options, args.target_opset, args.onnx_filename)


if __name__ == "__main__":
    main()
|
|