#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Convert a joblib-serialized model to ONNX format.

Usage::

    python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx

The dataset argument is used only to infer the model's input width when the
model itself does not record it (``n_features_in_``).
"""
import argparse
import sys

import joblib
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Dataset names accepted by load_dataset() and the CLI.
DATASET_NAMES = ("california", "diabetes", "iris", "cardiotocography")
# Opset used when --target-opset is not given (the previously hard-coded value).
DEFAULT_TARGET_OPSET = 12


def load_dataset(dataset_name):
    """Load a named dataset and return ``(X, y)``.

    Args:
        dataset_name: One of ``DATASET_NAMES``.

    Returns:
        Tuple ``(X, y)`` of feature matrix and target vector. For
        ``"cardiotocography"`` the target is binarized: 1 where the original
        string label equals ``"3"``, 0 otherwise.

    Raises:
        ValueError: If ``dataset_name`` is not recognised.
    """
    # Each sklearn loader is imported inside its own branch, so only the
    # requested one is imported.
    if dataset_name == "california":
        from sklearn.datasets import fetch_california_housing
        dataset = fetch_california_housing()
    elif dataset_name == "diabetes":
        from sklearn.datasets import load_diabetes
        dataset = load_diabetes()
    elif dataset_name == "iris":
        from sklearn.datasets import load_iris
        dataset = load_iris()
    elif dataset_name == "cardiotocography":
        from sklearn.datasets import fetch_openml
        dataset = fetch_openml(name=dataset_name, version=1, as_frame=False)
        # Binarize the string class labels: class "3" vs. all other classes.
        y = (dataset.target == "3").astype(int)
        return dataset.data, y
    else:
        raise ValueError(f"Invalid dataset name: {dataset_name!r}")
    return dataset.data, dataset.target


def prepare_onnx_conversion_params(X, target_opset, model):
    """Build the ONNX input tensor type and skl2onnx converter options.

    Args:
        X: Feature matrix of the dataset. Its column count is the fallback
            input width when the model does not expose ``n_features_in_``.
        target_opset: ONNX opset the model will be converted to.
        model: Fitted estimator that is being converted.

    Returns:
        Tuple ``(tensor, options)``. ``tensor`` is a ``FloatTensorType`` of
        shape ``[None, n_features]``, with the batch dimension left dynamic.
        ``options`` is a skl2onnx per-model options dict, or ``None``.
    """
    # Prefer the width the estimator was fitted with (sklearn >= 1.0 sets
    # n_features_in_). The old code hard-coded 4 for opsets 9/17, which only
    # matched iris.
    n_features = getattr(model, "n_features_in_", None)
    if n_features is None:
        n_features = X.shape[1]
    tensor = FloatTensorType([None, n_features])

    # zipmap=False makes classifier probabilities a plain tensor instead of a
    # list of dicts.
    # NOTE(review): the reason this is limited to opset 9 is not visible here.
    # skl2onnx rejects the option for estimators without a ZipMap output.
    options = {id(model): {"zipmap": False}} if target_opset == 9 else None
    return tensor, options


def convert2onnx(model, initial_type, options, target_opset, onnx_filename):
    """Write ``model`` to ``onnx_filename`` in ONNX format.

    The model's own ``save_model(..., format="onnx")`` exporter is tried first
    when it has one. Otherwise, or if that export fails, skl2onnx's
    ``convert_sklearn`` is used.

    Args:
        model: Fitted model to export.
        initial_type: skl2onnx ``initial_types`` list of ``(name, type)`` pairs.
        options: skl2onnx converter options, or ``None``.
        target_opset: ONNX opset for ``convert_sklearn``.
        onnx_filename: Output path.

    Returns:
        ``True`` if a file was written successfully, ``False`` otherwise.
        Errors are reported on stderr rather than raised.
    """
    save_model = getattr(model, "save_model", None)
    # sklearn estimators have no save_model, so skip straight to skl2onnx
    # instead of reporting a spurious AttributeError.
    if callable(save_model):
        try:
            save_model(onnx_filename, format="onnx")
        except Exception as e:  # native exporters raise library-specific errors
            print("Error occurred while saving model in ONNX format:", e, file=sys.stderr)
            print("Trying a second method...", file=sys.stderr)
        else:
            print("Model saved in ONNX format successfully.")
            return True

    try:
        onnx_model = convert_sklearn(
            model,
            initial_types=initial_type,
            options=options,
            target_opset=target_opset,
        )
        with open(onnx_filename, "wb") as f:
            f.write(onnx_model.SerializeToString())
    except Exception as e:  # converter and I/O failures are reported, not raised
        print("Error occurred while converting model to ONNX format:", e, file=sys.stderr)
        return False
    print("Model converted to ONNX format and saved successfully.")
    return True


def main(argv=None):
    """Parse CLI arguments, convert the model, and return a process exit code."""
    parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
    parser.add_argument(
        'dataset_name',
        type=str,
        choices=DATASET_NAMES,
        help='Name of the dataset. Choose from: "california", "diabetes", "iris" or "cardiotocography".',
    )
    parser.add_argument('model_path', type=str, help='Path to the trained model file.')
    parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
    parser.add_argument(
        '--target-opset',
        type=int,
        default=DEFAULT_TARGET_OPSET,
        help=f'ONNX opset to target (default: {DEFAULT_TARGET_OPSET}).',
    )
    args = parser.parse_args(argv)

    # Only X is needed: it supplies the fallback input width.
    X, _ = load_dataset(args.dataset_name)
    # joblib.load unpickles arbitrary objects: only load model files from
    # trusted sources.
    model = joblib.load(args.model_path)

    tensor, options = prepare_onnx_conversion_params(X, args.target_opset, model)
    initial_type = [("float_input", tensor)]
    ok = convert2onnx(model, initial_type, options, args.target_opset, args.onnx_filename)
    # Non-zero exit lets shell scripts / CI detect a failed conversion.
    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())