Commit
•
478d418
1
Parent(s):
b7e9713
Added new dataset
Browse files- scikit-learn/convert2onnx.py +11 -2
scikit-learn/convert2onnx.py
CHANGED
@@ -4,18 +4,27 @@
|
|
4 |
import argparse
|
5 |
import joblib
|
6 |
|
7 |
-
from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris
|
8 |
from skl2onnx import convert_sklearn
|
9 |
from skl2onnx.common.data_types import FloatTensorType
|
10 |
|
11 |
|
12 |
def load_dataset(dataset_name):
|
13 |
if dataset_name == 'california':
|
|
|
14 |
dataset = fetch_california_housing()
|
15 |
elif dataset_name == 'diabetes':
|
|
|
16 |
dataset = load_diabetes()
|
17 |
elif dataset_name == 'iris':
|
|
|
18 |
dataset = load_iris()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
else:
|
20 |
raise ValueError("Invalid dataset name")
|
21 |
return dataset.data, dataset.target
|
@@ -53,7 +62,7 @@ python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.o
|
|
53 |
"""
|
54 |
if __name__ == "__main__":
|
55 |
parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
|
56 |
-
parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", or "
|
57 |
parser.add_argument('model_path', type=str, help='Path to the trained model file.')
|
58 |
parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
|
59 |
args = parser.parse_args()
|
|
|
4 |
import argparse
|
5 |
import joblib
|
6 |
|
|
|
7 |
from skl2onnx import convert_sklearn
|
8 |
from skl2onnx.common.data_types import FloatTensorType
|
9 |
|
10 |
|
11 |
def load_dataset(dataset_name):
|
12 |
if dataset_name == 'california':
|
13 |
+
from sklearn.datasets import fetch_california_housing
|
14 |
dataset = fetch_california_housing()
|
15 |
elif dataset_name == 'diabetes':
|
16 |
+
from sklearn.datasets import load_diabetes
|
17 |
dataset = load_diabetes()
|
18 |
elif dataset_name == 'iris':
|
19 |
+
from sklearn.datasets import load_iris
|
20 |
dataset = load_iris()
|
21 |
+
elif dataset_name == "cardiotocography":
|
22 |
+
from sklearn.datasets import fetch_openml
|
23 |
+
dataset = fetch_openml(name=dataset_name, version=1, as_frame=False)
|
24 |
+
X, y = dataset.data, dataset.target
|
25 |
+
s = y == "3"
|
26 |
+
y = s.astype(int)
|
27 |
+
return X, y
|
28 |
else:
|
29 |
raise ValueError("Invalid dataset name")
|
30 |
return dataset.data, dataset.target
|
|
|
62 |
"""
|
63 |
if __name__ == "__main__":
|
64 |
parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
|
65 |
+
parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", "iris" or "cardiotocography".')
|
66 |
parser.add_argument('model_path', type=str, help='Path to the trained model file.')
|
67 |
parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
|
68 |
args = parser.parse_args()
|