romanbredehoft-zama committed on
Commit 74c0c8e
1 Parent(s): b0303a0

Add second model for optional explainability step

app.py CHANGED
@@ -26,6 +26,7 @@ from backend import (
     run_fhe,
     get_output,
     decrypt_output,
+    years_employed_encrypt_run_decrypt,
 )
 
 
@@ -60,6 +61,12 @@ with demo:
     )
     client_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
 
+    # Button generate the keys
+    keygen_button.click(
+        keygen_send,
+        outputs=[client_id, evaluation_key, keygen_button],
+    )
+
     gr.Markdown("## Step 2: Fill in some information.")
     gr.Markdown(
         """
@@ -125,6 +132,31 @@ with demo:
         label="Encrypted input representation:", max_lines=2, interactive=False
     )
 
+    # Button to pre-process, generate the key, encrypt and send the user inputs from the client
+    # side to the server
+    encrypt_button_user.click(
+        pre_process_encrypt_send_user,
+        inputs=[client_id, bool_inputs, num_children, household_size, total_income, age, \
+            income_type, education_type, family_status, occupation_type, housing_type],
+        outputs=[encrypted_input_user],
+    )
+
+    # Button to pre-process, generate the key, encrypt and send the bank inputs from the client
+    # side to the server
+    encrypt_button_bank.click(
+        pre_process_encrypt_send_bank,
+        inputs=[client_id, account_age],
+        outputs=[encrypted_input_bank],
+    )
+
+    # Button to pre-process, generate the key, encrypt and send the third party inputs from the
+    # client side to the server
+    encrypt_button_third_party.click(
+        pre_process_encrypt_send_third_party,
+        inputs=[client_id, employed, years_employed],
+        outputs=[encrypted_input_third_party],
+    )
+
     gr.Markdown("# Server side")
     gr.Markdown(
         """
@@ -142,6 +174,9 @@ with demo:
         label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
    )
 
+    # Button to send the encodings to the server using post method
+    execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
+
     gr.Markdown("# Client side")
     gr.Markdown(
         """
@@ -161,6 +196,13 @@ with demo:
         label="Encrypted output representation: ", max_lines=2, interactive=False
     )
 
+    # Button to send the encodings to the server using post method
+    get_output_button.click(
+        get_output,
+        inputs=[client_id],
+        outputs=[encrypted_output_representation],
+    )
+
     gr.Markdown("## Step 6: Decrypt the output.")
     gr.Markdown(
         """
@@ -173,52 +215,39 @@ with demo:
         label="Prediction", max_lines=1, interactive=False
     )
 
-    # Button generate the keys
-    keygen_button.click(
-        keygen_send,
-        outputs=[client_id, evaluation_key, keygen_button],
-    )
-
-    # Button to pre-process, generate the key, encrypt and send the user inputs from the client
-    # side to the server
-    encrypt_button_user.click(
-        pre_process_encrypt_send_user,
-        inputs=[client_id, bool_inputs, num_children, household_size, total_income, age, \
-            income_type, education_type, family_status, occupation_type, housing_type],
-        outputs=[encrypted_input_user],
+    # Button to decrypt the output
+    decrypt_button.click(
+        decrypt_output,
+        inputs=[client_id],
+        outputs=[prediction_output],
     )
 
-    # Button to pre-process, generate the key, encrypt and send the bank inputs from the client
-    # side to the server
-    encrypt_button_bank.click(
-        pre_process_encrypt_send_bank,
-        inputs=[client_id, account_age],
-        outputs=[encrypted_input_bank],
+    gr.Markdown("## Step 7 (optional): Explain the prediction.")
+    gr.Markdown(
+        """
+        In case the credit card is likely to be denied, the user can run a second model in order to
+        explain the prediction better. More specifically, this new model indicates the number of
+        additional years of employment that could be required in order to increase the chance of
+        credit card approval.
+        All of the above steps are combined into a single button for simplicity. The following
+        button therefore encrypts the same inputs (except the years of employment) from all three
+        parties, runs the new prediction in FHE and decrypts the output.
+        """
     )
-
-    # Button to pre-process, generate the key, encrypt and send the third party inputs from the
-    # client side to the server
-    encrypt_button_third_party.click(
-        pre_process_encrypt_send_third_party,
-        inputs=[client_id, employed, years_employed],
-        outputs=[encrypted_input_third_party],
+    years_employed_prediction_button = gr.Button(
+        "Encrypt the inputs, compute in FHE and decrypt the output."
    )
-
-    # Button to send the encodings to the server using post method
-    execute_fhe_button.click(run_fhe, inputs=[client_id], outputs=[fhe_execution_time])
-
-    # Button to send the encodings to the server using post method
-    get_output_button.click(
-        get_output,
-        inputs=[client_id],
-        outputs=[encrypted_output_representation],
+    years_employed_prediction = gr.Textbox(
+        label="Additional years of employment required.", max_lines=1, interactive=False
    )
 
-    # Button to decrypt the output
-    decrypt_button.click(
-        decrypt_output,
-        inputs=[client_id],
-        outputs=[prediction_output],
+    # Button to explain the prediction
+    years_employed_prediction_button.click(
+        years_employed_encrypt_run_decrypt,
+        inputs=[client_id, prediction_output, bool_inputs, num_children, household_size, \
+            total_income, age, income_type, education_type, family_status, occupation_type, \
+            housing_type, account_age, employed, years_employed],
+        outputs=[years_employed_prediction],
    )
 
     gr.Markdown(
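The Step 7 button passes `client_id`, the decrypted prediction and the raw form values to `years_employed_encrypt_run_decrypt` positionally, so the order of the `inputs=` list in `app.py` has to match the unpacking order in `backend.py`. A minimal, self-contained sketch of that Gradio wiring pattern (component and function names here are illustrative, not taken from the app):

```python
import gradio as gr

def explain(client_id, prediction, years_employed):
    # Values arrive in the same order as the `inputs=` list below.
    return f"client={client_id}, prediction={prediction}, years={years_employed}"

with gr.Blocks() as demo:
    client_id = gr.Textbox(label="Client ID")
    prediction = gr.Textbox(label="Prediction")
    years_employed = gr.Number(label="Years employed")
    result = gr.Textbox(label="Result")

    explain_button = gr.Button("Explain")
    explain_button.click(
        explain,
        inputs=[client_id, prediction, years_employed],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()
```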
backend.py CHANGED
@@ -14,20 +14,26 @@ from settings import (
     FHE_KEYS,
     CLIENT_FILES,
     SERVER_FILES,
-    DEPLOYMENT_PATH,
-    PROCESSED_INPUT_SHAPE,
+    APPROVAL_DEPLOYMENT_PATH,
+    EXPLAIN_DEPLOYMENT_PATH,
+    APPROVAL_PROCESSED_INPUT_SHAPE,
+    EXPLAIN_PROCESSED_INPUT_SHAPE,
     INPUT_INDEXES,
-    INPUT_SLICES,
+    APPROVAL_INPUT_SLICES,
+    EXPLAIN_INPUT_SLICES,
     PRE_PROCESSOR_USER_PATH,
     PRE_PROCESSOR_BANK_PATH,
     PRE_PROCESSOR_THIRD_PARTY_PATH,
     CLIENT_TYPES,
     USER_COLUMNS,
     BANK_COLUMNS,
-    THIRD_PARTY_COLUMNS,
+    APPROVAL_THIRD_PARTY_COLUMNS,
 )
 
-from utils.client_server_interface import MultiInputsFHEModelClient
+from utils.client_server_interface import MultiInputsFHEModelClient, MultiInputsFHEModelServer
+
+# Load the server used for explaining the prediction
+EXPLAIN_FHE_SERVER = MultiInputsFHEModelServer(EXPLAIN_DEPLOYMENT_PATH)
 
 # Load pre-processor instances
 with (
@@ -87,18 +93,22 @@ def clean_temporary_files(n_keys=20):
         shutil.rmtree(directory)
 
 
-def _get_client(client_id):
+def _get_client(client_id, is_approval=True):
     """Get the client instance.
 
     Args:
         client_id (int): The client ID to consider.
+        is_approval (bool): Whether the client represents the 'approval' model (else, it
+            represents the 'explain' model). Defaults to True.
 
     Returns:
         FHEModelClient: The client instance.
     """
-    key_dir = FHE_KEYS / f"{client_id}"
+    key_suffix = "approval" if is_approval else "explain"
+    key_dir = FHE_KEYS / f"{client_id}_{key_suffix}"
+    client_dir = APPROVAL_DEPLOYMENT_PATH if is_approval else EXPLAIN_DEPLOYMENT_PATH
 
-    return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
+    return MultiInputsFHEModelClient(client_dir, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
 
 
 def _get_client_file_path(name, client_id, client_type=None):
@@ -196,7 +206,7 @@ def keygen_send():
     return client_id, evaluation_key_short, gr.update(value="Keys are generated and evaluation key is sent ✅")
 
 
-def _encrypt_send(client_id, inputs, client_type):
+def _encrypt_send(client_id, inputs, client_type, app_mode=True):
     """Encrypt the given inputs for a specific client and send it to the server.
 
     Args:
@@ -205,8 +215,7 @@ def _encrypt_send(client_id, inputs, client_type):
         client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
 
     Returns:
-        client_id, encrypted_inputs_short (int, bytes): Integer ID representing the current client
-            and a byte short representation of the encrypted input to send.
+        encrypted_inputs_short (str): A short representation of the encrypted input to send in hex.
     """
     if client_id == "":
         raise gr.Error("Please generate the keys first.")
@@ -218,8 +227,8 @@ def _encrypt_send(client_id, inputs, client_type):
     encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
         inputs,
         input_index=INPUT_INDEXES[client_type],
-        processed_input_shape=PROCESSED_INPUT_SHAPE,
-        input_slice=INPUT_SLICES[client_type],
+        processed_input_shape=APPROVAL_PROCESSED_INPUT_SHAPE,
+        input_slice=APPROVAL_INPUT_SLICES[client_type],
     )
 
     file_name = "encrypted_inputs"
@@ -239,16 +248,14 @@ def _encrypt_send(client_id, inputs, client_type):
     return encrypted_inputs_short
 
 
-def pre_process_encrypt_send_user(client_id, *inputs):
-    """Pre-process, encrypt and send the user inputs for a specific client to the server.
+def _pre_process_user(*inputs):
+    """Pre-process the user inputs.
 
     Args:
-        client_id (str): The current client ID to consider.
         *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
 
     Returns:
-        (int, bytes): Integer ID representing the current client and a byte short representation of
-            the encrypted input to send.
+        (numpy.ndarray): The pre-processed inputs.
     """
     bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
         family_status, occupation_type, housing_type = inputs
@@ -277,19 +284,32 @@ def pre_process_encrypt_send_user(client_id, *inputs):
 
     preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
 
+    return preprocessed_user_inputs
+
+
+def pre_process_encrypt_send_user(client_id, *inputs):
+    """Pre-process, encrypt and send the user inputs for a specific client to the server.
+
+    Args:
+        client_id (str): The current client ID to consider.
+        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
+
+    Returns:
+        (str): A short representation of the encrypted input to send in hex.
+    """
+    preprocessed_user_inputs = _pre_process_user(*inputs)
+
     return _encrypt_send(client_id, preprocessed_user_inputs, "user")
 
 
-def pre_process_encrypt_send_bank(client_id, *inputs):
-    """Pre-process, encrypt and send the bank inputs for a specific client to the server.
+def _pre_process_bank(*inputs):
+    """Pre-process the bank inputs.
 
     Args:
-        client_id (str): The current client ID to consider.
         *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
 
     Returns:
-        (int, bytes): Integer ID representing the current client and a byte short representation of
-            the encrypted input to send.
+        (numpy.ndarray): The pre-processed inputs.
     """
     account_age = inputs[0]
 
@@ -301,32 +321,65 @@ def pre_process_encrypt_send_bank(client_id, *inputs):
 
     preprocessed_bank_inputs = PRE_PROCESSOR_BANK.transform(bank_inputs)
 
+    return preprocessed_bank_inputs
+
+
+def pre_process_encrypt_send_bank(client_id, *inputs):
+    """Pre-process, encrypt and send the bank inputs for a specific client to the server.
+
+    Args:
+        client_id (str): The current client ID to consider.
+        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
+
+    Returns:
+        (str): A short representation of the encrypted input to send in hex.
+    """
+    preprocessed_bank_inputs = _pre_process_bank(*inputs)
+
     return _encrypt_send(client_id, preprocessed_bank_inputs, "bank")
 
 
-def pre_process_encrypt_send_third_party(client_id, *inputs):
-    """Pre-process, encrypt and send the third party inputs for a specific client to the server.
+def _pre_process_third_party(*inputs):
+    """Pre-process the third party inputs.
 
     Args:
-        client_id (str): The current client ID to consider.
         *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
 
     Returns:
-        (int, bytes): Integer ID representing the current client and a byte short representation of
-            the encrypted input to send.
+        (numpy.ndarray): The pre-processed inputs.
     """
-    employed, years_employed = inputs
+    third_party_data = {}
+    if len(inputs) == 1:
+        employed = inputs[0]
+    else:
+        employed, years_employed = inputs
+        third_party_data["Years_employed"] = [years_employed]
 
     is_employed = employed == "Yes"
+    third_party_data["Employed"] = [is_employed]
 
-    third_party_inputs = pandas.DataFrame({
-        "Employed": [is_employed],
-        "Years_employed": [years_employed],
-    })
+    third_party_inputs = pandas.DataFrame(third_party_data)
+
+    if len(inputs) == 1:
+        preprocessed_third_party_inputs = third_party_inputs.to_numpy()
+    else:
+        third_party_inputs = third_party_inputs.reindex(APPROVAL_THIRD_PARTY_COLUMNS, axis=1)
+        preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
+
+    return preprocessed_third_party_inputs
+
+
+def pre_process_encrypt_send_third_party(client_id, *inputs):
+    """Pre-process, encrypt and send the third party inputs for a specific client to the server.
 
-    third_party_inputs = third_party_inputs.reindex(THIRD_PARTY_COLUMNS, axis=1)
+    Args:
+        client_id (str): The current client ID to consider.
+        *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
 
-    preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
+    Returns:
+        (str): A short representation of the encrypted input to send in hex.
+    """
+    preprocessed_third_party_inputs = _pre_process_third_party(*inputs)
 
     return _encrypt_send(client_id, preprocessed_third_party_inputs, "third_party")
 
@@ -430,4 +483,86 @@ def decrypt_output(client_id):
     # Determine the predicted class
     output = numpy.argmax(output_proba, axis=1).squeeze()
 
-    return "Credit card is likely to be approved ✅" if output == 1 else "Credit card is likely to be denied ❌"
+    return "Credit card is likely to be approved ✅" if output == 1 else "Credit card is likely to be denied ❌"
+
+
+def years_employed_encrypt_run_decrypt(client_id, prediction_output, *inputs):
+    """Pre-process and encrypt the inputs, run the prediction in FHE and decrypt the output.
+
+    Args:
+        client_id (str): The current client ID to consider.
+        prediction_output (str): The initial prediction output. This parameter is only used to
+            throw an error in case the prediction was positive.
+        *inputs (Tuple[numpy.ndarray]): The inputs to consider.
+
+    Returns:
+        (str): A message indicating the number of additional years of employment that could be
+            required in order to increase the chance of
+            credit card approval.
+    """
+
+    if "approved" in prediction_output:
+        raise gr.Error(
+            "Explaining the prediction can only be done if the credit card is likely to be denied."
+        )
+
+    # Retrieve the client instance
+    client = _get_client(client_id, is_approval=False)
+
+    # Generate the private and evaluation keys
+    client.generate_private_and_evaluation_keys(force=False)
+
+    # Retrieve the serialized evaluation key
+    evaluation_key = client.get_serialized_evaluation_keys()
+
+    bool_inputs, num_children, household_size, total_income, age, income_type, education_type, \
+        family_status, occupation_type, housing_type, account_age, employed, years_employed = inputs
+
+    preprocessed_user_inputs = _pre_process_user(
+        bool_inputs, num_children, household_size, total_income, age, income_type, education_type,
+        family_status, occupation_type, housing_type,
+    )
+    preprocessed_bank_inputs = _pre_process_bank(account_age)
+    preprocessed_third_party_inputs = _pre_process_third_party(employed)
+
+    preprocessed_inputs = [
+        preprocessed_user_inputs,
+        preprocessed_bank_inputs,
+        preprocessed_third_party_inputs
+    ]
+
+    # Quantize, encrypt and serialize the inputs
+    encrypted_inputs = []
+    for client_type, preprocessed_input in zip(CLIENT_TYPES, preprocessed_inputs):
+        encrypted_input = client.quantize_encrypt_serialize_multi_inputs(
+            preprocessed_input,
+            input_index=INPUT_INDEXES[client_type],
+            processed_input_shape=EXPLAIN_PROCESSED_INPUT_SHAPE,
+            input_slice=EXPLAIN_INPUT_SLICES[client_type],
+        )
+        encrypted_inputs.append(encrypted_input)
+
+    # Run the FHE computation
+    encrypted_output = EXPLAIN_FHE_SERVER.run(
+        *encrypted_inputs,
+        serialized_evaluation_keys=evaluation_key
+    )
+
+    # Decrypt the output
+    output_prediction = client.deserialize_decrypt_dequantize(encrypted_output)
+
+    # Get the difference with the initial 'years of employment' input
+    years_employed_diff = int(numpy.ceil(output_prediction.squeeze() - years_employed))
+
+    if years_employed_diff > 0:
+        return (
+            f"Having at least {years_employed_diff} more years of employment would increase "
+            "your chance of having your credit card approved."
+        )
+
+    return (
+        "The number of years of employment you provided seems to be enough. The negative prediction "
+        "might come from other inputs."
+    )
+
+
deployment_files/{client.zip → approval_model/client.zip} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7bad4947dfc472f67c4ac52c5a26077177b8993ee8b1541ae3fb7c473d94d7fb
-size 28647
+oid sha256:e2ceb4a6e07cd13471c8c8c963d9e4de52d5af624e81775ebeb2421e29b9ba8c
+size 28667
deployment_files/{server.zip → approval_model/server.zip} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0b1e87acc2acda1565b6b23ea82be8d6c6cc4b3747106502f73ebc62397cceaa
-size 1731
+oid sha256:9e724012427c90fdc8df14360942909e5fa0accc8b27584880baab2a91533e78
+size 1729
deployment_files/{versions.json → approval_model/versions.json} RENAMED
File without changes
deployment_files/explain_model/client.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:506276661b4612d664d59f0d90aac1b5c09f942a850ec189aa16204d54433b27
+size 27714
deployment_files/explain_model/server.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:596ae66c7effd9733a8780088984d4fc08479d67c11586ee5787111329cb353f
+size 2035
deployment_files/explain_model/versions.json ADDED
@@ -0,0 +1 @@
+{"concrete-python": "2.5.0rc1", "concrete-ml": "1.3.0", "python": "3.10.11"}
development.py CHANGED
@@ -6,28 +6,49 @@ import pandas
 import pickle
 
 from settings import (
-    DEPLOYMENT_PATH,
+    APPROVAL_DEPLOYMENT_PATH,
+    EXPLAIN_DEPLOYMENT_PATH,
     DATA_PATH,
-    INPUT_SLICES,
+    APPROVAL_INPUT_SLICES,
+    EXPLAIN_INPUT_SLICES,
     PRE_PROCESSOR_USER_PATH,
     PRE_PROCESSOR_BANK_PATH,
     PRE_PROCESSOR_THIRD_PARTY_PATH,
     USER_COLUMNS,
     BANK_COLUMNS,
-    THIRD_PARTY_COLUMNS,
+    APPROVAL_THIRD_PARTY_COLUMNS,
+    EXPLAIN_THIRD_PARTY_COLUMNS,
 )
 from utils.client_server_interface import MultiInputsFHEModelDev
-from utils.model import MultiInputDecisionTreeClassifier
+from utils.model import MultiInputDecisionTreeClassifier, MultiInputDecisionTreeRegressor
 from utils.pre_processing import get_pre_processors
 
 
-def get_processed_multi_inputs(data):
+def get_multi_inputs(data, is_approval):
+    """Get inputs for all three parties from the input data, using fixed slices.
+
+    Args:
+        data (numpy.ndarray): The input data to consider.
+        is_approval (bool): If the data should be used for the 'approval' model (else, for
+            the 'explain' model).
+
+    Returns:
+        (Tuple[numpy.ndarray]): The inputs for all three parties.
+    """
+    if is_approval:
+        return (
+            data[:, APPROVAL_INPUT_SLICES["user"]],
+            data[:, APPROVAL_INPUT_SLICES["bank"]],
+            data[:, APPROVAL_INPUT_SLICES["third_party"]]
+        )
+
     return (
-        data[:, INPUT_SLICES["user"]],
-        data[:, INPUT_SLICES["bank"]],
-        data[:, INPUT_SLICES["third_party"]]
+        data[:, EXPLAIN_INPUT_SLICES["user"]],
+        data[:, EXPLAIN_INPUT_SLICES["bank"]],
+        data[:, EXPLAIN_INPUT_SLICES["third_party"]]
     )
 
+
 print("Load and pre-process the data")
 
 # Load the data
@@ -40,7 +61,7 @@ data_y = data_x.pop("Target").copy().to_frame()
 # Get data from all parties
 data_user = data_x[USER_COLUMNS].copy()
 data_bank = data_x[BANK_COLUMNS].copy()
-data_third_party = data_x[THIRD_PARTY_COLUMNS].copy()
+data_third_party = data_x[APPROVAL_THIRD_PARTY_COLUMNS].copy()
 
 # Feature engineer the data
 pre_processor_user, pre_processor_bank, pre_processor_third_party = get_pre_processors()
@@ -54,23 +75,23 @@ preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_da
 
 print("\nTrain and compile the model")
 
-model = MultiInputDecisionTreeClassifier()
+model_approval = MultiInputDecisionTreeClassifier()
 
-model, sklearn_model = model.fit_benchmark(preprocessed_data_x, data_y)
+model_approval, sklearn_model_approval = model_approval.fit_benchmark(preprocessed_data_x, data_y)
 
-multi_inputs_train = get_processed_multi_inputs(preprocessed_data_x)
+multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=True)
 
-model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
+model_approval.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
 
 print("\nSave deployment files")
 
 # Delete the deployment folder and its content if it already exists
-if DEPLOYMENT_PATH.is_dir():
-    shutil.rmtree(DEPLOYMENT_PATH)
+if APPROVAL_DEPLOYMENT_PATH.is_dir():
+    shutil.rmtree(APPROVAL_DEPLOYMENT_PATH)
 
 # Save files needed for deployment (and enable cross-platform deployment)
-fhe_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
-fhe_dev.save(via_mlir=True)
+fhe_model_dev_approval = MultiInputsFHEModelDev(APPROVAL_DEPLOYMENT_PATH, model_approval)
+fhe_model_dev_approval.save(via_mlir=True)
 
 # Save pre-processors
 with (
@@ -82,4 +103,44 @@ with (
     pickle.dump(pre_processor_bank, file_bank)
     pickle.dump(pre_processor_third_party, file_third_party)
 
+
+print("\nLoad, train, compile and save files for the 'explain' model")
+
+# Define input and target data
+data_x = data.copy()
+data_y = data_x.pop("Years_employed").copy().to_frame()
+target_values = data_x.pop("Target").copy()
+
+# Get all data points whose target value is True (credit card has been approved)
+approved_mask = target_values == 1
+data_x_approved = data_x[approved_mask]
+data_y_approved = data_y[approved_mask]
+
+# Get data from all parties
+data_user = data_x_approved[USER_COLUMNS].copy()
+data_bank = data_x_approved[BANK_COLUMNS].copy()
+data_third_party = data_x_approved[EXPLAIN_THIRD_PARTY_COLUMNS].copy()
+
+preprocessed_data_user = pre_processor_user.transform(data_user)
+preprocessed_data_bank = pre_processor_bank.transform(data_bank)
+preprocessed_data_third_party = data_third_party.to_numpy()
+
+preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1)
+
+model_explain = MultiInputDecisionTreeRegressor()
+
+model_explain, sklearn_model_explain = model_explain.fit_benchmark(preprocessed_data_x, data_y_approved)
+
+multi_inputs_train = get_multi_inputs(preprocessed_data_x, is_approval=False)
+
+model_explain.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
+
+# Delete the deployment folder and its content if it already exists
+if EXPLAIN_DEPLOYMENT_PATH.is_dir():
+    shutil.rmtree(EXPLAIN_DEPLOYMENT_PATH)
+
+# Save files needed for deployment (and enable cross-platform deployment)
+fhe_model_dev_explain = MultiInputsFHEModelDev(EXPLAIN_DEPLOYMENT_PATH, model_explain)
+fhe_model_dev_explain.save(via_mlir=True)
+
 print("\nDone !")
server.py CHANGED
@@ -5,11 +5,11 @@ from typing import List, Optional
 from fastapi import FastAPI, File, Form, UploadFile
 from fastapi.responses import JSONResponse, Response
 
-from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
+from settings import APPROVAL_DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
 from utils.client_server_interface import MultiInputsFHEModelServer
 
 # Load the server
-FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH)
+FHE_SERVER = MultiInputsFHEModelServer(APPROVAL_DEPLOYMENT_PATH)
 
 
 def _get_server_file_path(name, client_id, client_type=None):
settings.py CHANGED
@@ -6,12 +6,16 @@ import pandas
 # The directory of this project
 REPO_DIR = Path(__file__).parent
 
-# This repository's main necessary directories
+# Main necessary directories
 DEPLOYMENT_PATH = REPO_DIR / "deployment_files"
 FHE_KEYS = REPO_DIR / ".fhe_keys"
 CLIENT_FILES = REPO_DIR / "client_files"
 SERVER_FILES = REPO_DIR / "server_files"
 
+# All deployment directories
+APPROVAL_DEPLOYMENT_PATH = DEPLOYMENT_PATH / "approval_model"
+EXPLAIN_DEPLOYMENT_PATH = DEPLOYMENT_PATH / "explain_model"
+
 # Path targeting pre-processor saved files
 PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
 PRE_PROCESSOR_BANK_PATH = DEPLOYMENT_PATH / 'pre_processor_bank.pkl'
@@ -29,7 +33,8 @@ SERVER_URL = "http://localhost:8000/"
 DATA_PATH = "data/data.csv"
 
 # Development settings
-PROCESSED_INPUT_SHAPE = (1, 39)
+APPROVAL_PROCESSED_INPUT_SHAPE = (1, 39)
+EXPLAIN_PROCESSED_INPUT_SHAPE = (1, 38)
 
 CLIENT_TYPES = ["user", "bank", "third_party"]
 INPUT_INDEXES = {
@@ -37,19 +42,26 @@ INPUT_INDEXES = {
     "bank": 1,
     "third_party": 2,
 }
-INPUT_SLICES = {
+APPROVAL_INPUT_SLICES = {
     "user": slice(0, 36), # First position: start from 0
     "bank": slice(36, 37), # Second position: start from n_feature_user
     "third_party": slice(37, 39), # Third position: start from n_feature_user + n_feature_bank
 }
+EXPLAIN_INPUT_SLICES = {
+    "user": slice(0, 36), # First position: start from 0
+    "bank": slice(36, 37), # Second position: start from n_feature_user
+    "third_party": slice(37, 38), # Third position: start from n_feature_user + n_feature_bank
+}
 
+# Fix column order for pre-processing steps
 USER_COLUMNS = [
     'Own_car', 'Own_property', 'Mobile_phone', 'Num_children', 'Household_size',
     'Total_income', 'Age', 'Income_type', 'Education_type', 'Family_status', 'Housing_type',
     'Occupation_type',
 ]
 BANK_COLUMNS = ["Account_age"]
-THIRD_PARTY_COLUMNS = ["Years_employed", "Employed"]
+APPROVAL_THIRD_PARTY_COLUMNS = ["Years_employed", "Employed"]
+EXPLAIN_THIRD_PARTY_COLUMNS = ["Employed"]
 
 _data = pandas.read_csv(DATA_PATH, encoding="utf-8")
 
utils/client_server_interface.py CHANGED
@@ -3,10 +3,11 @@
 import numpy
 import copy
 
-from concrete.fhe import Value, EvaluationKeys
+from typing import Tuple
 
+from concrete.fhe import Value, EvaluationKeys
 from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
-from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
+from concrete.ml.sklearn import DecisionTreeClassifier
 
 
 class MultiInputsFHEModelDev(FHEModelDev):
@@ -15,8 +16,9 @@ class MultiInputsFHEModelDev(FHEModelDev):
 
         super().__init__(*arg, **kwargs)
 
+        # Workaround that enables loading a modified version of a DecisionTreeClassifier model
         model = copy.copy(self.model)
-        model.__class__ = ConcreteXGBClassifier
+        model.__class__ = DecisionTreeClassifier
         self.model = model
 
 
@@ -30,10 +32,27 @@ class MultiInputsFHEModelClient(FHEModelClient):
     def quantize_encrypt_serialize_multi_inputs(
         self,
         x: numpy.ndarray,
-        input_index,
-        processed_input_shape,
-        input_slice
+        input_index: int,
+        processed_input_shape: Tuple[int],
+        input_slice: slice,
     ) -> bytes:
+        """Quantize, encrypt and serialize inputs for a multi-party model.
+
+        In the following, the 'quantize_input' method called is the one defined in Concrete ML's
+        built-in models. Since they don't natively handle inputs for multi-party models, we need
+        to use padding along with indexing and slicing so that inputs from a specific party are
+        correctly associated with input quantizers.
+
+        Args:
+            x (numpy.ndarray): The input to consider. Here, the input should only represent a
+                single party.
+            input_index (int): The index representing the type of model (0: "user", 1: "bank",
+                2: "third_party")
+            processed_input_shape (Tuple[int]): The total input shape (all parties combined) after
+                pre-processing.
+            input_slice (slice): The slices to consider for the given party.
+
+        """
 
         x_padded = numpy.zeros(processed_input_shape)
 
@@ -58,15 +77,15 @@ class MultiInputsFHEModelServer(FHEModelServer):
 
     def run(
         self,
-        *serialized_encrypted_quantized_data: bytes,
+        *serialized_encrypted_quantized_data: Tuple[bytes],
         serialized_evaluation_keys: bytes,
     ) -> bytes:
-        """Run the model on the server over encrypted data.
+        """Run the model on the server over encrypted data for a multi-party model.
 
         Args:
-            serialized_encrypted_quantized_data (bytes): the encrypted, quantized
-                and serialized data
-            serialized_evaluation_keys (bytes): the serialized evaluation keys
+            serialized_encrypted_quantized_data (Tuple[bytes]): The encrypted, quantized
+                and serialized data for a multi-party model.
+            serialized_evaluation_keys (bytes): The serialized evaluation key.
 
         Returns:
             bytes: the result of the model
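The docstring above describes the padding trick: each party's columns are written into a zero array with the full processed shape, so that quantizers fitted on the complete input see values at the expected positions, and the party's own slice is then read back out. A simplified sketch of just the padding/slicing step (no quantization or encryption; shapes are the approval model's):

```python
import numpy

processed_input_shape = (1, 39)
input_slice = slice(36, 37)  # the bank party's single column

# The bank's pre-processed input: a single column
x = numpy.array([[0.42]])

# Pad it into a full-width row so column positions match the model's quantizers
x_padded = numpy.zeros(processed_input_shape)
x_padded[:, input_slice] = x

# After (hypothetical) quantization of the padded row, only the party's
# own columns would be kept and encrypted
x_party = x_padded[:, input_slice]
print(x_party)  # [[0.42]]
```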
utils/model.py CHANGED
@@ -13,7 +13,7 @@ from concrete.ml.common.utils import (
     check_there_is_no_p_error_options_in_configuration
 )
 from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator
-from concrete.ml.sklearn import DecisionTreeClassifier
+from concrete.ml.sklearn import DecisionTreeClassifier, DecisionTreeRegressor
 
 class MultiInputModel:
 
@@ -131,46 +131,8 @@ class MultiInputModel:
 
         return compiler
 
-    def predict_multi_inputs(self, *multi_inputs, simulate=True):
-        """Run the inference with multiple inputs, with simulation or in FHE."""
-        assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs)
-
-        if not simulate:
-            self.fhe_circuit.keygen()
-
-        y_preds = []
-        execution_times = []
-        for inputs in zip(*multi_inputs):
-            inputs = tuple(numpy.expand_dims(input, axis=0) for input in inputs)
-
-            q_inputs = self.quantize_input(*inputs)
-
-            if simulate:
-                q_y_proba = self.fhe_circuit.simulate(*q_inputs)
-            else:
-                q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs)
-
-                start = time.time()
-                q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc)
-                end = time.time() - start
-
-                execution_times.append(end)
-
-                q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc)
-
-            y_proba = self.dequantize_output(q_y_proba)
-
-            y_proba = self.post_processing(y_proba)
-
-            y_pred = numpy.argmax(y_proba, axis=1)
-
-            y_preds.append(y_pred)
-
-        if not simulate:
-            print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
-
-        return numpy.array(y_preds)
-
-
 class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier):
-    pass
+    pass
+
+class MultiInputDecisionTreeRegressor(MultiInputModel, DecisionTreeRegressor):
+    pass
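Both wrapper classes lean on Python's method resolution order: because `MultiInputModel` comes first in the bases, its overrides (such as the multi-input `compile`) shadow the corresponding Concrete ML methods, while everything else is inherited unchanged. A tiny standalone illustration of that pattern with toy classes (not the real ones):

```python
class Base:
    def compile(self):
        return "base compile"

class MultiInputMixin:
    def compile(self):
        return "multi-input compile"

class Model(MultiInputMixin, Base):
    pass

# The mixin is looked up before the base class
print(Model.__mro__[1].__name__)  # MultiInputMixin
print(Model().compile())          # multi-input compile
```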