File size: 2,980 Bytes
48946cf
 
 
b34166e
 
 
 
 
 
 
 
 
478d418
b34166e
 
478d418
b34166e
 
478d418
b34166e
478d418
 
 
 
 
 
 
b34166e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e1349
 
b34166e
c1e1349
 
 
 
 
 
 
 
 
b34166e
 
 
 
 
 
478d418
b34166e
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import joblib

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType


def load_dataset(dataset_name):
    """Load a named dataset and return its ``(features, targets)`` pair.

    Supported names are ``'california'``, ``'diabetes'`` and ``'iris'``
    (bundled sklearn loaders) plus ``"cardiotocography"`` (fetched from
    OpenML, version 1, as NumPy arrays rather than a DataFrame).  Each
    sklearn import happens only inside the branch that needs it.

    For ``"cardiotocography"`` the target is binarised: entries equal to the
    string ``"3"`` become 1 and every other label becomes 0.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    # OpenML path handled first: it needs post-processing of the target,
    # so it returns directly instead of sharing the generic tail below.
    if dataset_name == "cardiotocography":
        from sklearn.datasets import fetch_openml
        bunch = fetch_openml(name=dataset_name, version=1, as_frame=False)
        features, labels = bunch.data, bunch.target
        is_three = labels == "3"
        return features, is_three.astype(int)

    # Bundled datasets: pick the matching loader, then unpack uniformly.
    if dataset_name == 'california':
        from sklearn.datasets import fetch_california_housing as loader
    elif dataset_name == 'diabetes':
        from sklearn.datasets import load_diabetes as loader
    elif dataset_name == 'iris':
        from sklearn.datasets import load_iris as loader
    else:
        raise ValueError("Invalid dataset name")

    bunch = loader()
    return bunch.data, bunch.target


def prepare_onnx_conversion_params(X, target_opset, model):
    """Build the ONNX input type and converter options for *model*.

    Args:
        X: feature matrix the model is used with; only ``X.shape[1]`` (the
            number of feature columns) is read.
        target_opset: ONNX opset the model will be converted for.
        model: the fitted estimator; its ``id`` keys the per-model options.

    Returns:
        ``(tensor, options)``: ``tensor`` is a ``FloatTensorType`` with a
        dynamic batch dimension and one column per feature in ``X``;
        ``options`` is a dict disabling the ``zipmap`` output for opset 9,
        otherwise ``None``.
    """
    # Batch size stays dynamic (None).  The feature width always comes from
    # the data: a fixed width (previously 4 for opsets 9/17, which only fits
    # iris) would declare an input that disagrees with the trained model.
    tensor = FloatTensorType([None, X.shape[1]])
    # For opset 9, ask skl2onnx not to wrap classifier probabilities in a
    # ZipMap node, so they are emitted as a plain tensor.
    options = {id(model): {'zipmap': False}} if target_opset == 9 else None
    return tensor, options


def convert2onnx(model, initial_type, options, target_opset, onnx_filename):
    """Write *model* to *onnx_filename* in ONNX format.

    Two strategies are tried in order:

    1. If the model has a callable ``save_model`` attribute, call it as
       ``save_model(onnx_filename, format="onnx")``.
    2. Otherwise, or if that call raises, convert the model with skl2onnx's
       ``convert_sklearn`` and write the serialized bytes to the file.

    Errors are printed rather than raised (best-effort behaviour), so the
    caller keeps running; the return value reports the outcome.

    Args:
        model: the fitted model to export.
        initial_type: list of ``(input_name, tensor_type)`` pairs passed to
            ``convert_sklearn`` as ``initial_types``.
        options: converter options passed to ``convert_sklearn`` (may be None).
        target_opset: ONNX opset passed to ``convert_sklearn``.
        onnx_filename: path of the output file.

    Returns:
        bool: ``True`` if one of the strategies completed without raising,
        ``False`` if both failed.
    """
    # Only try the model's own exporter when it actually has one; sklearn
    # estimators do not, and probing them would print a misleading error.
    native_save = getattr(model, "save_model", None)
    if callable(native_save):
        try:
            native_save(onnx_filename, format="onnx")
        except Exception as e:
            print("Error occurred while saving model in ONNX format:", e)
            print("Trying a second method...")
        else:
            print("Model saved in ONNX format successfully.")
            return True

    try:
        onnx_model = convert_sklearn(model, initial_types=initial_type, options=options, target_opset=target_opset)
        with open(onnx_filename, "wb") as f:
            f.write(onnx_model.SerializeToString())
    except Exception as e:
        # Best-effort: report and signal failure instead of crashing.
        print("Error occurred while converting model to ONNX format:", e)
        return False
    print("Model converted to ONNX format and saved successfully.")
    return True

"""
python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.onnx
"""
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
    parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", "iris" or "cardiotocography".')
    parser.add_argument('model_path', type=str, help='Path to the trained model file.')
    parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
    args = parser.parse_args()

    X, y = load_dataset(args.dataset_name)
    model = joblib.load(args.model_path)
    target_opset = 12
    tensor, options = prepare_onnx_conversion_params(X, target_opset, model)

    input_name = 'float_input'
    initial_type = [(input_name, tensor)]
    convert2onnx(model, initial_type, options, target_opset, args.onnx_filename)