import argparse |
import joblib |
from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris |
from skl2onnx import convert_sklearn |
from skl2onnx.common.data_types import FloatTensorType |
def load_dataset(dataset_name):
    """Load one of the supported scikit-learn datasets by name.

    Args:
        dataset_name: ``'california'``, ``'diabetes'`` or ``'iris'``.

    Returns:
        A ``(data, target)`` tuple read from the loaded dataset object.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Reject unknown names up front so the happy path below stays flat.
    if dataset_name not in ('california', 'diabetes', 'iris'):
        raise ValueError("Invalid dataset name")

    # Only the loader for the requested dataset is ever called.
    if dataset_name == 'california':
        bunch = fetch_california_housing()
    elif dataset_name == 'diabetes':
        bunch = load_diabetes()
    else:
        bunch = load_iris()

    features, labels = bunch.data, bunch.target
    return features, labels
def prepare_onnx_conversion_params(X, target_opset, model):
    """Build the input tensor type and converter options for ONNX export.

    Args:
        X: Feature matrix; only ``X.shape[1]`` (its column count) is read.
        target_opset: ONNX opset version the model will be converted to.
        model: The model to convert; used only to key the options dict
            by ``id(model)``.

    Returns:
        A ``(tensor, options)`` tuple. ``tensor`` is a ``FloatTensorType``
        of shape ``[None, X.shape[1]]``. ``options`` is
        ``{id(model): {'zipmap': False}}`` when ``target_opset`` is 9 and
        ``None`` otherwise.
    """
    # The input width must follow the data the model was trained on.
    # An earlier version hard-coded 4 columns for opsets 9 and 17, which
    # produced a wrong input signature for any dataset with a different
    # number of features; the opset has no bearing on the feature count.
    tensor = FloatTensorType([None, X.shape[1]])

    # NOTE(review): zipmap is switched off only for opset 9; the reason is
    # not visible in this file -- confirm before extending to other opsets.
    options = {id(model): {'zipmap': False}} if target_opset == 9 else None

    return tensor, options
def convert2onnx(model, initial_type, options, target_opset, onnx_filename):
    """Write ``model`` to ``onnx_filename`` in ONNX format, trying two routes.

    The model's own ``save_model(..., format="onnx")`` method is tried first.
    If that raises, the model is converted with ``convert_sklearn`` and the
    serialized result is written to the file. Failures of either route are
    reported on stdout and never propagated.

    Args:
        model: The model object to export.
        initial_type: Passed to ``convert_sklearn`` as ``initial_types``.
        options: Passed to ``convert_sklearn`` as ``options``.
        target_opset: Passed to ``convert_sklearn`` as ``target_opset``.
        onnx_filename: Destination path of the ONNX file.

    Returns:
        None.
    """
    # First route: the model's native exporter.
    try:
        model.save_model(onnx_filename, format="onnx")
        print("Model saved in ONNX format successfully.")
        return
    except Exception as native_error:
        print("Error occurred while saving model in ONNX format:", native_error)
        print("Trying a second method...")

    # Second route: generic conversion; only reached when the first failed.
    try:
        converted = convert_sklearn(
            model,
            initial_types=initial_type,
            options=options,
            target_opset=target_opset,
        )
        with open(onnx_filename, "wb") as sink:
            sink.write(converted.SerializeToString())
        print("Model converted to ONNX format and saved successfully.")
    except Exception as conversion_error:
        print("Error occurred while converting model to ONNX format:", conversion_error)
""" |
Example usage:
    python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx
""" |
def main():
    """Command-line entry point: load a saved model and export it to ONNX.

    Positional arguments are the dataset name, the path of the saved model
    and the output ONNX filename. The dataset is loaded only to determine
    the shape of the feature matrix handed to
    ``prepare_onnx_conversion_params``.
    """
    parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
    parser.add_argument(
        'dataset_name',
        type=str,
        # Reject unknown names at parse time with a usage message instead
        # of a ValueError traceback later on.
        choices=('california', 'diabetes', 'iris'),
        help='Name of the dataset. Choose from: "california", "diabetes", or "iris".',
    )
    parser.add_argument('model_path', type=str, help='Path to the trained model file.')
    parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
    parser.add_argument(
        '--target-opset',
        type=int,
        default=12,
        help='ONNX opset version to target (default: 12).',
    )
    args = parser.parse_args()

    # Only the feature matrix is needed; the targets are discarded.
    X, _ = load_dataset(args.dataset_name)

    # SECURITY: joblib.load unpickles the file and can execute arbitrary
    # code -- only pass model files from a trusted source.
    model = joblib.load(args.model_path)

    target_opset = args.target_opset
    tensor, options = prepare_onnx_conversion_params(X, target_opset, model)

    input_name = 'float_input'
    initial_type = [(input_name, tensor)]
    convert2onnx(model, initial_type, options, target_opset, args.onnx_filename)


if __name__ == "__main__":
    main()