romanbredehoft-zama commited on
Commit
b0303a0
1 Parent(s): 18ba8c1

Change model to decision tree

Browse files
deployment_files/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:859c5895ff365bfd7fa8b592784717d14094c29aaee032c2f23716d225c15855
3
- size 29354
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bad4947dfc472f67c4ac52c5a26077177b8993ee8b1541ae3fb7c473d94d7fb
3
+ size 28647
deployment_files/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3490a405f8634b3c95f6a2d3fb8cf0276e3f49adc4b25911afab7f97524c2f7a
3
- size 2308
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b1e87acc2acda1565b6b23ea82be8d6c6cc4b3747106502f73ebc62397cceaa
3
+ size 1731
development.py CHANGED
@@ -17,7 +17,7 @@ from settings import (
17
  THIRD_PARTY_COLUMNS,
18
  )
19
  from utils.client_server_interface import MultiInputsFHEModelDev
20
- from utils.model import MultiInputXGBClassifier
21
  from utils.pre_processing import get_pre_processors
22
 
23
 
@@ -54,7 +54,7 @@ preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_da
54
 
55
  print("\nTrain and compile the model")
56
 
57
- model = MultiInputXGBClassifier(max_depth=3, n_estimators=20)
58
 
59
  model, sklearn_model = model.fit_benchmark(preprocessed_data_x, data_y)
60
 
@@ -62,13 +62,12 @@ multi_inputs_train = get_processed_multi_inputs(preprocessed_data_x)
62
 
63
  model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
64
 
 
 
65
  # Delete the deployment folder and its content if it already exists
66
  if DEPLOYMENT_PATH.is_dir():
67
  shutil.rmtree(DEPLOYMENT_PATH)
68
 
69
-
70
- print("\nSave deployment files")
71
-
72
  # Save files needed for deployment (and enable cross-platform deployment)
73
  fhe_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
74
  fhe_dev.save(via_mlir=True)
 
17
  THIRD_PARTY_COLUMNS,
18
  )
19
  from utils.client_server_interface import MultiInputsFHEModelDev
20
+ from utils.model import MultiInputDecisionTreeClassifier
21
  from utils.pre_processing import get_pre_processors
22
 
23
 
 
54
 
55
  print("\nTrain and compile the model")
56
 
57
+ model = MultiInputDecisionTreeClassifier()
58
 
59
  model, sklearn_model = model.fit_benchmark(preprocessed_data_x, data_y)
60
 
 
62
 
63
  model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
64
 
65
+ print("\nSave deployment files")
66
+
67
  # Delete the deployment folder and its content if it already exists
68
  if DEPLOYMENT_PATH.is_dir():
69
  shutil.rmtree(DEPLOYMENT_PATH)
70
 
 
 
 
71
  # Save files needed for deployment (and enable cross-platform deployment)
72
  fhe_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
73
  fhe_dev.save(via_mlir=True)
utils/model.py CHANGED
@@ -13,10 +13,9 @@ from concrete.ml.common.utils import (
13
  check_there_is_no_p_error_options_in_configuration
14
  )
15
  from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
16
- from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
17
 
18
-
19
- class MultiInputXGBClassifier(ConcreteXGBClassifier):
20
 
21
  def quantize_input(self, *X: numpy.ndarray) -> numpy.ndarray:
22
  self.check_model_is_fitted()
@@ -171,3 +170,7 @@ class MultiInputXGBClassifier(ConcreteXGBClassifier):
171
  print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
172
 
173
  return numpy.array(y_preds)
 
 
 
 
 
13
  check_there_is_no_p_error_options_in_configuration
14
  )
15
  from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
16
+ from concrete.ml.sklearn import DecisionTreeClassifier
17
 
18
+ class MultiInputModel:
 
19
 
20
  def quantize_input(self, *X: numpy.ndarray) -> numpy.ndarray:
21
  self.check_model_is_fitted()
 
170
  print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
171
 
172
  return numpy.array(y_preds)
173
+
174
+
175
+ class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
176
+ pass