andrewssobral commited on
Commit
478d418
1 Parent(s): b7e9713

Added new dataset

Browse files
Files changed (1) hide show
  1. scikit-learn/convert2onnx.py +11 -2
scikit-learn/convert2onnx.py CHANGED
@@ -4,18 +4,27 @@
4
  import argparse
5
  import joblib
6
 
7
- from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris
8
  from skl2onnx import convert_sklearn
9
  from skl2onnx.common.data_types import FloatTensorType
10
 
11
 
12
  def load_dataset(dataset_name):
13
  if dataset_name == 'california':
 
14
  dataset = fetch_california_housing()
15
  elif dataset_name == 'diabetes':
 
16
  dataset = load_diabetes()
17
  elif dataset_name == 'iris':
 
18
  dataset = load_iris()
 
 
 
 
 
 
 
19
  else:
20
  raise ValueError("Invalid dataset name")
21
  return dataset.data, dataset.target
@@ -53,7 +62,7 @@ python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.o
53
  """
54
  if __name__ == "__main__":
55
  parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
56
- parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", or "iris".')
57
  parser.add_argument('model_path', type=str, help='Path to the trained model file.')
58
  parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
59
  args = parser.parse_args()
 
4
  import argparse
5
  import joblib
6
 
 
7
  from skl2onnx import convert_sklearn
8
  from skl2onnx.common.data_types import FloatTensorType
9
 
10
 
11
  def load_dataset(dataset_name):
12
  if dataset_name == 'california':
13
+ from sklearn.datasets import fetch_california_housing
14
  dataset = fetch_california_housing()
15
  elif dataset_name == 'diabetes':
16
+ from sklearn.datasets import load_diabetes
17
  dataset = load_diabetes()
18
  elif dataset_name == 'iris':
19
+ from sklearn.datasets import load_iris
20
  dataset = load_iris()
21
+ elif dataset_name == "cardiotocography":
22
+ from sklearn.datasets import fetch_openml
23
+ dataset = fetch_openml(name=dataset_name, version=1, as_frame=False)
24
+ X, y = dataset.data, dataset.target
25
+ s = y == "3"
26
+ y = s.astype(int)
27
+ return X, y
28
  else:
29
  raise ValueError("Invalid dataset name")
30
  return dataset.data, dataset.target
 
62
  """
63
  if __name__ == "__main__":
64
  parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
65
+ parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", "iris" or "cardiotocography".')
66
  parser.add_argument('model_path', type=str, help='Path to the trained model file.')
67
  parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
68
  args = parser.parse_args()