romanbredehoft-zama commited on
Commit
9a997e4
β€’
1 Parent(s): 4337a72

First working demo with multi-inputs XGB

Browse files
app.py CHANGED
@@ -1,318 +1,39 @@
1
- """A local gradio app that filters images using FHE."""
2
 
3
- import os
4
- import shutil
5
  import subprocess
6
  import time
7
  import gradio as gr
8
- import numpy
9
- import requests
10
- from itertools import chain
11
 
12
  from settings import (
13
  REPO_DIR,
14
- SERVER_URL,
15
- FHE_KEYS,
16
- CLIENT_FILES,
17
- SERVER_FILES,
18
- DEPLOYMENT_PATH,
19
- INITIAL_INPUT_SHAPE,
20
- INPUT_INDEXES,
21
- START_POSITIONS,
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
-
24
- from development.client_server_interface import MultiInputsFHEModelClient
25
 
26
 
27
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
28
  time.sleep(3)
29
 
30
 
31
- def shorten_bytes_object(bytes_object, limit=500):
32
- """Shorten the input bytes object to a given length.
33
-
34
- Encrypted data is too large for displaying it in the browser using Gradio. This function
35
- provides a shorten representation of it.
36
-
37
- Args:
38
- bytes_object (bytes): The input to shorten
39
- limit (int): The length to consider. Default to 500.
40
-
41
- Returns:
42
- str: Hexadecimal string shorten representation of the input byte object.
43
-
44
- """
45
- # Define a shift for better display
46
- shift = 100
47
- return bytes_object[shift : limit + shift].hex()
48
-
49
-
50
- def get_client(client_id, client_type):
51
- """Get the client API.
52
-
53
- Args:
54
- client_id (int): The client ID to consider.
55
- client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
56
-
57
- Returns:
58
- FHEModelClient: The client API.
59
- """
60
- key_dir = FHE_KEYS / f"{client_type}_{client_id}"
61
-
62
- return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir)
63
-
64
-
65
- def get_client_file_path(name, client_id, client_type):
66
- """Get the correct temporary file path for the client.
67
-
68
- Args:
69
- name (str): The desired file name (either 'evaluation_key' or 'encrypted_inputs').
70
- client_id (int): The client ID to consider.
71
- client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
72
-
73
- Returns:
74
- pathlib.Path: The file path.
75
- """
76
- return CLIENT_FILES / f"{name}_{client_type}_{client_id}"
77
-
78
-
79
- def clean_temporary_files(n_keys=20):
80
- """Clean keys and encrypted images.
81
-
82
- A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
83
- limit is reached, the oldest files are deleted.
84
-
85
- Args:
86
- n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
87
-
88
- """
89
- # Get the oldest key files in the key directory
90
- key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)
91
-
92
- # If more than n_keys keys are found, remove the oldest
93
- user_ids = []
94
- if len(key_dirs) > n_keys:
95
- n_keys_to_delete = len(key_dirs) - n_keys
96
- for key_dir in key_dirs[:n_keys_to_delete]:
97
- user_ids.append(key_dir.name)
98
- shutil.rmtree(key_dir)
99
-
100
- # Get all the encrypted objects in the temporary folder
101
- client_files = CLIENT_FILES.iterdir()
102
- server_files = SERVER_FILES.iterdir()
103
-
104
- # Delete all files related to the ids whose keys were deleted
105
- for file in chain(client_files, server_files):
106
- for user_id in user_ids:
107
- if user_id in file.name:
108
- file.unlink()
109
-
110
-
111
- def keygen(client_id, client_type):
112
- """Generate the private key associated to a filter.
113
-
114
- Args:
115
- client_id (int): The client ID to consider.
116
- client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
117
- """
118
- # Clean temporary files
119
- clean_temporary_files()
120
-
121
- # Retrieve the client instance
122
- client = get_client(client_id, client_type)
123
-
124
- # Generate a private key
125
- client.generate_private_and_evaluation_keys(force=True)
126
-
127
- # Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this
128
- # evaluation key is empty. However, for software reasons, it is still needed for proper FHE
129
- # execution
130
- evaluation_key = client.get_serialized_evaluation_keys()
131
-
132
- # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
133
- # buttons (see https://github.com/gradio-app/gradio/issues/1877)
134
- evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
135
-
136
- with evaluation_key_path.open("wb") as evaluation_key_file:
137
- evaluation_key_file.write(evaluation_key)
138
-
139
-
140
- def send_input(client_id, client_type):
141
- """Send the encrypted input image as well as the evaluation key to the server.
142
-
143
- Args:
144
- client_id (int): The client ID to consider.
145
- client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
146
- """
147
- # Get the paths to the evaluation key and encrypted inputs
148
- evaluation_key_path = get_client_file_path("evaluation_key", client_id, client_type)
149
- encrypted_input_path = get_client_file_path("encrypted_inputs", client_id, client_type)
150
-
151
- # Define the data and files to post
152
- data = {
153
- "client_id": client_id,
154
- "client_type": client_type,
155
- }
156
-
157
- files = [
158
- ("files", open(encrypted_input_path, "rb")),
159
- ("files", open(evaluation_key_path, "rb")),
160
- ]
161
-
162
- # Send the encrypted input image and evaluation key to the server
163
- url = SERVER_URL + "send_input"
164
- with requests.post(
165
- url=url,
166
- data=data,
167
- files=files,
168
- ) as response:
169
- return response.ok
170
-
171
-
172
- def keygen_encrypt_send(inputs, client_type):
173
- """Encrypt the given inputs for a specific client.
174
-
175
- Args:
176
- inputs (numpy.ndarray): The inputs to encrypt.
177
- client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
178
-
179
- Returns:
180
-
181
- """
182
- # Create an ID for the current client to consider
183
- client_id = numpy.random.randint(0, 2**32)
184
-
185
- keygen(client_id, client_type)
186
-
187
- # Retrieve the client instance
188
- client = get_client(client_id, client_type)
189
-
190
- # TODO : pre-process the data first
191
-
192
- # Quantize, encrypt and serialize the inputs
193
- encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
194
- inputs,
195
- input_index=INPUT_INDEXES[client_type],
196
- initial_input_shape=INITIAL_INPUT_SHAPE,
197
- start_position=START_POSITIONS[client_type],
198
- )
199
-
200
- # Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
201
- # buttons, https://github.com/gradio-app/gradio/issues/1877
202
- encrypted_inputs_path = get_client_file_path("encrypted_inputs", client_id, client_type)
203
-
204
- with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
205
- encrypted_inputs_file.write(encrypted_inputs)
206
-
207
- # Create a truncated version of the encrypted image for display
208
- encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
209
-
210
- send_input(client_id, client_type)
211
-
212
- # TODO: also return private key representation if possible
213
- return encrypted_inputs_short
214
-
215
-
216
- def run_fhe(client_id):
217
- """Run the model on the encrypted inputs previously sent using FHE.
218
-
219
- Args:
220
- client_id (int): The client ID to consider.
221
- """
222
-
223
- # TODO : add a warning for users to send all client types' inputs
224
-
225
- data = {
226
- "client_id": client_id,
227
- }
228
-
229
- # Trigger the FHE execution on the encrypted inputs previously sent
230
- url = SERVER_URL + "run_fhe"
231
- with requests.post(
232
- url=url,
233
- data=data,
234
- ) as response:
235
- if response.ok:
236
- return response.json()
237
- else:
238
- raise gr.Error("Please wait for the inputs to be sent to the server.")
239
-
240
-
241
- def get_output(client_id):
242
- """Retrieve the encrypted output.
243
-
244
- Args:
245
- client_id (int): The client ID to consider.
246
-
247
- Returns:
248
- output_encrypted_representation (numpy.ndarray): A representation of the encrypted output.
249
-
250
- """
251
- data = {
252
- "client_id": client_id,
253
- }
254
-
255
- # Retrieve the encrypted output image
256
- url = SERVER_URL + "get_output"
257
- with requests.post(
258
- url=url,
259
- data=data,
260
- ) as response:
261
- if response.ok:
262
- encrypted_output = response.content
263
-
264
- # Save the encrypted output to bytes in a file as it is too large to pass through regular
265
- # Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
266
- # TODO : check if output to user is relevant
267
- encrypted_output_path = get_client_file_path("encrypted_output", client_id, "user")
268
-
269
- with encrypted_output_path.open("wb") as encrypted_output_file:
270
- encrypted_output_file.write(encrypted_output)
271
-
272
- # TODO
273
- # Decrypt the output using a different (wrong) key for display
274
- # output_encrypted_representation = decrypt_output_with_wrong_key(encrypted_output, client_type)
275
-
276
- # return output_encrypted_representation
277
-
278
- return None
279
- else:
280
- raise gr.Error("Please wait for the FHE execution to be completed.")
281
-
282
-
283
- def decrypt_output(client_id, client_type):
284
- """Decrypt the result.
285
-
286
- Args:
287
- client_id (int): The client ID to consider.
288
- client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
289
-
290
- Returns:
291
- output(numpy.ndarray): The decrypted output
292
-
293
- """
294
- # Get the encrypted output path
295
- encrypted_output_path = get_client_file_path("encrypted_output", client_id, client_type)
296
-
297
- if not encrypted_output_path.is_file():
298
- raise gr.Error("Please run the FHE execution first.")
299
-
300
- # Load the encrypted output as bytes
301
- with encrypted_output_path.open("rb") as encrypted_output_file:
302
- encrypted_output_proba = encrypted_output_file.read()
303
-
304
- # Retrieve the client API
305
- client = get_client(client_id, client_type)
306
-
307
- # Deserialize, decrypt and post-process the encrypted output
308
- output_proba = client.deserialize_decrypt_post_process(encrypted_output_proba)
309
-
310
- # Determine the predicted class
311
- output = numpy.argmax(output_proba, axis=1)
312
-
313
- return output
314
-
315
-
316
  demo = gr.Blocks()
317
 
318
 
@@ -330,60 +51,68 @@ with demo:
330
  with gr.Row():
331
  with gr.Column():
332
  gr.Markdown("### User")
333
- # TODO : change infos
334
- choice_1 = gr.Dropdown(choices=["Yes, No"], label="Choose", interactive=True)
335
- slide_1 = gr.Slider(2, 20, value=4, label="Count", info="Choose between 2 and 20")
 
 
 
 
 
 
 
 
336
 
337
  with gr.Column():
338
  gr.Markdown("### Bank ")
339
- # TODO : change infos
340
- checkbox_1 = gr.CheckboxGroup(["USA", "Japan", "Pakistan"], label="Countries", info="Where are they from?")
341
 
342
  with gr.Column():
343
- gr.Markdown("### Third Party ")
344
- # TODO : change infos
345
- radio_1 = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
346
-
347
 
348
  gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
349
  with gr.Row():
350
  with gr.Column():
351
  gr.Markdown("### User")
352
  encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
353
- keys_user = gr.Textbox(
354
- label="Keys representation:", max_lines=2, interactive=False
355
- )
356
  encrypted_input_user = gr.Textbox(
357
  label="Encrypted input representation:", max_lines=2, interactive=False
358
  )
 
 
 
359
 
360
- user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
361
 
362
 
363
  with gr.Column():
364
  gr.Markdown("### Bank ")
365
  encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
366
- keys_bank = gr.Textbox(
367
- label="Keys representation:", max_lines=2, interactive=False
368
- )
369
  encrypted_input_bank = gr.Textbox(
370
  label="Encrypted input representation:", max_lines=2, interactive=False
371
  )
372
-
373
- bank_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
 
374
 
375
 
376
  with gr.Column():
377
  gr.Markdown("### Third Party ")
378
  encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
379
- keys_3 = gr.Textbox(
380
- label="Keys representation:", max_lines=2, interactive=False
381
- )
382
- encrypted_input__third_party = gr.Textbox(
383
- label="Encrypted input representation:", max_lines=2, interactive=False
384
- )
385
 
386
  third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
 
 
 
 
 
 
387
 
388
  gr.Markdown("## Server side")
389
  gr.Markdown(
@@ -412,9 +141,9 @@ with demo:
412
  )
413
  get_output_button = gr.Button("Receive the encrypted output from the server.")
414
 
415
- encrypted_output_representation = gr.Textbox(
416
- label="Encrypted output representation: ", max_lines=1, interactive=False
417
- )
418
 
419
  gr.Markdown("### Step 8: Decrypt the output.")
420
  decrypt_button = gr.Button("Decrypt the output")
@@ -423,48 +152,50 @@ with demo:
423
  label="Credit card approval decision: ", max_lines=1, interactive=False
424
  )
425
 
426
- # Button to encrypt inputs on the client side
427
- # encrypt_button_user.click(
428
- # encrypt,
429
- # inputs=[user_id, input_image, filter_name],
430
- # outputs=[original_image, encrypted_input],
431
- # )
432
-
433
- # # Button to encrypt inputs on the client side
434
- # encrypt_button_bank.click(
435
- # encrypt,
436
- # inputs=[user_id, input_image, filter_name],
437
- # outputs=[original_image, encrypted_input],
438
- # )
439
 
440
- # # Button to encrypt inputs on the client side
441
- # encrypt_button_third_party.click(
442
- # encrypt,
443
- # inputs=[user_id, input_image, filter_name],
444
- # outputs=[original_image, encrypted_input],
445
- # )
 
446
 
447
- # # Button to send the encodings to the server using post method
448
- # send_input_button.click(
449
- # send_input, inputs=[user_id, filter_name], outputs=[send_input_checkbox]
450
- # )
 
 
 
451
 
452
- # # Button to send the encodings to the server using post method
453
- # execute_fhe_button.click(run_fhe, inputs=[user_id, filter_name], outputs=[fhe_execution_time])
 
454
 
455
- # # Button to send the encodings to the server using post method
456
- # get_output_button.click(
457
- # get_output,
458
- # inputs=[user_id, filter_name],
459
- # outputs=[encrypted_output_representation]
460
- # )
 
461
 
462
- # # Button to decrypt the output on the client side
463
- # decrypt_button.click(
464
- # decrypt_output,
465
- # inputs=[user_id, filter_name],
466
- # outputs=[output_image, keygen_checkbox, send_input_checkbox],
467
- # )
 
468
 
469
  gr.Markdown(
470
  "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
 
1
+ """A gradio app for credit card approval prediction using FHE."""
2
 
 
 
3
  import subprocess
4
  import time
5
  import gradio as gr
 
 
 
6
 
7
  from settings import (
8
  REPO_DIR,
9
+ ACCOUNT_MIN_MAX,
10
+ CHILDREN_MIN_MAX,
11
+ INCOME_MIN_MAX,
12
+ AGE_MIN_MAX,
13
+ EMPLOYED_MIN_MAX,
14
+ FAMILY_MIN_MAX,
15
+ INCOME_TYPES,
16
+ OCCUPATION_TYPES,
17
+ HOUSING_TYPES,
18
+ EDUCATION_TYPES,
19
+ FAMILY_STATUS,
20
+ )
21
+ from backend import (
22
+ shorten_bytes_object,
23
+ clean_temporary_files,
24
+ pre_process_keygen_encrypt_send_user,
25
+ pre_process_keygen_encrypt_send_bank,
26
+ pre_process_keygen_encrypt_send_third_party,
27
+ run_fhe,
28
+ get_output,
29
+ decrypt_output,
30
  )
 
 
31
 
32
 
33
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
34
  time.sleep(3)
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  demo = gr.Blocks()
38
 
39
 
 
51
  with gr.Row():
52
  with gr.Column():
53
  gr.Markdown("### User")
54
+ gender = gr.Radio(["Female", "Male"], label="Gender")
55
+ bool_inputs = gr.CheckboxGroup(["Car", "Property", "Work phone", "Phone", "Email"], label="What do you own ?")
56
+ num_children = gr.Slider(**CHILDREN_MIN_MAX, step=1, label="Number of children", info="How many children do you have (0 to 19) ?")
57
+ num_family = gr.Slider(**FAMILY_MIN_MAX, step=1, label="Family", info="How many members does your family have? (1 to 20) ?")
58
+ total_income = gr.Slider(**INCOME_MIN_MAX, label="Income", info="What's you total yearly income (in euros, 3780 to 220500) ?")
59
+ age = gr.Slider(**AGE_MIN_MAX, step=1, label="Age", info="How old are you (20 to 68) ?")
60
+ income_type = gr.Dropdown(choices=INCOME_TYPES, label="Income type", info="What is your main type of income ?")
61
+ education_type = gr.Dropdown(choices=EDUCATION_TYPES, label="Education", info="What is your education background ?")
62
+ family_status = gr.Dropdown(choices=FAMILY_STATUS, label="Family", info="What is your family status ?")
63
+ occupation_type = gr.Dropdown(choices=OCCUPATION_TYPES, label="Occupation", info="What is your main occupation ?")
64
+ housing_type = gr.Dropdown(choices=HOUSING_TYPES, label="Housing", info="In what type of housing do you live ?")
65
 
66
  with gr.Column():
67
  gr.Markdown("### Bank ")
68
+ account_length = gr.Slider(**ACCOUNT_MIN_MAX, step=1, label="Account length", info="How long have this person had this account (in months, 0 to 60) ?")
 
69
 
70
  with gr.Column():
71
+ gr.Markdown("### Third party ")
72
+ employed = gr.Radio(["Yes", "No"], label="Is the person employed ?")
73
+ years_employed = gr.Slider(**EMPLOYED_MIN_MAX, step=1, label="Years of employment", info="How long have this person been employed (in years, 0 to 43) ?")
74
+
75
 
76
  gr.Markdown("### Step 2: Keygen, encrypt using FHE and send the inputs to the server.")
77
  with gr.Row():
78
  with gr.Column():
79
  gr.Markdown("### User")
80
  encrypt_button_user = gr.Button("Encrypt the inputs and send to server.")
81
+
82
+ user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
 
83
  encrypted_input_user = gr.Textbox(
84
  label="Encrypted input representation:", max_lines=2, interactive=False
85
  )
86
+ # keys_user = gr.Textbox(
87
+ # label="Keys representation:", max_lines=2, interactive=False
88
+ # )
89
 
 
90
 
91
 
92
  with gr.Column():
93
  gr.Markdown("### Bank ")
94
  encrypt_button_bank = gr.Button("Encrypt the inputs and send to server.")
95
+
96
+ bank_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
 
97
  encrypted_input_bank = gr.Textbox(
98
  label="Encrypted input representation:", max_lines=2, interactive=False
99
  )
100
+ # keys_bank = gr.Textbox(
101
+ # label="Keys representation:", max_lines=2, interactive=False
102
+ # )
103
 
104
 
105
  with gr.Column():
106
  gr.Markdown("### Third Party ")
107
  encrypt_button_third_party = gr.Button("Encrypt the inputs and send to server.")
 
 
 
 
 
 
108
 
109
  third_party_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
110
+ encrypted_input_third_party = gr.Textbox(
111
+ label="Encrypted input representation:", max_lines=2, interactive=False
112
+ )
113
+ # keys_3 = gr.Textbox(
114
+ # label="Keys representation:", max_lines=2, interactive=False
115
+ # )
116
 
117
  gr.Markdown("## Server side")
118
  gr.Markdown(
 
141
  )
142
  get_output_button = gr.Button("Receive the encrypted output from the server.")
143
 
144
+ # encrypted_output_representation = gr.Textbox(
145
+ # label="Encrypted output representation: ", max_lines=1, interactive=False
146
+ # )
147
 
148
  gr.Markdown("### Step 8: Decrypt the output.")
149
  decrypt_button = gr.Button("Decrypt the output")
 
152
  label="Credit card approval decision: ", max_lines=1, interactive=False
153
  )
154
 
155
+ # Button to pre-process, generate the key, encrypt and send the user inputs from the client
156
+ # side to the server
157
+ encrypt_button_user.click(
158
+ pre_process_keygen_encrypt_send_user,
159
+ inputs=[gender, bool_inputs, num_children, num_family, total_income, age, income_type, \
160
+ education_type, family_status, occupation_type, housing_type],
161
+ outputs=[user_id, encrypted_input_user],
162
+ )
 
 
 
 
 
163
 
164
+ # Button to pre-process, generate the key, encrypt and send the bank inputs from the client
165
+ # side to the server
166
+ encrypt_button_bank.click(
167
+ pre_process_keygen_encrypt_send_bank,
168
+ inputs=[account_length],
169
+ outputs=[bank_id, encrypted_input_bank],
170
+ )
171
 
172
+ # Button to pre-process, generate the key, encrypt and send the third party inputs from the
173
+ # client side to the server
174
+ encrypt_button_third_party.click(
175
+ pre_process_keygen_encrypt_send_third_party,
176
+ inputs=[employed, years_employed],
177
+ outputs=[third_party_id, encrypted_input_third_party],
178
+ )
179
 
180
+ # TODO : ID should be unique
181
+ # Button to send the encodings to the server using post method
182
+ execute_fhe_button.click(run_fhe, inputs=[user_id, bank_id, third_party_id], outputs=[fhe_execution_time])
183
 
184
+ # TODO : ID should be unique
185
+ # Button to send the encodings to the server using post method
186
+ get_output_button.click(
187
+ get_output,
188
+ inputs=[user_id, bank_id, third_party_id],
189
+ # outputs=[encrypted_output_representation]
190
+ )
191
 
192
+ # TODO : ID should be unique
193
+ # Button to decrypt the output as the user
194
+ decrypt_button.click(
195
+ decrypt_output,
196
+ inputs=[user_id, bank_id, third_party_id],
197
+ outputs=[prediction_output],
198
+ )
199
 
200
  gr.Markdown(
201
  "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
backend.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backend functions used in the app."""
2
+
3
+ import os
4
+ import shutil
5
+ import gradio as gr
6
+ import numpy
7
+ import requests
8
+ import pickle
9
+ import pandas
10
+ from itertools import chain
11
+
12
+ from settings import (
13
+ SERVER_URL,
14
+ FHE_KEYS,
15
+ CLIENT_FILES,
16
+ SERVER_FILES,
17
+ DEPLOYMENT_PATH,
18
+ INITIAL_INPUT_SHAPE,
19
+ INPUT_INDEXES,
20
+ INPUT_SLICES,
21
+ PRE_PROCESSOR_USER_PATH,
22
+ PRE_PROCESSOR_THIRD_PARTY_PATH,
23
+ CLIENT_TYPES,
24
+ )
25
+
26
+ from utils.client_server_interface import MultiInputsFHEModelClient
27
+
28
+ # Load pre-processor instances
29
+ with PRE_PROCESSOR_USER_PATH.open('rb') as file:
30
+ PRE_PROCESSOR_USER = pickle.load(file)
31
+
32
+ with PRE_PROCESSOR_THIRD_PARTY_PATH.open('rb') as file:
33
+ PRE_PROCESSOR_THIRD_PARTY = pickle.load(file)
34
+
35
+
36
+ def shorten_bytes_object(bytes_object, limit=500):
37
+ """Shorten the input bytes object to a given length.
38
+
39
+ Encrypted data is too large for displaying it in the browser using Gradio. This function
40
+ provides a shorten representation of it.
41
+
42
+ Args:
43
+ bytes_object (bytes): The input to shorten
44
+ limit (int): The length to consider. Default to 500.
45
+
46
+ Returns:
47
+ str: Hexadecimal string shorten representation of the input byte object.
48
+
49
+ """
50
+ # Define a shift for better display
51
+ shift = 100
52
+ return bytes_object[shift : limit + shift].hex()
53
+
54
+
55
+ def clean_temporary_files(n_keys=20):
56
+ """Clean keys and encrypted images.
57
+
58
+ A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
59
+ limit is reached, the oldest files are deleted.
60
+
61
+ Args:
62
+ n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
63
+
64
+ """
65
+ # Get the oldest key files in the key directory
66
+ key_dirs = sorted(FHE_KEYS.iterdir(), key=os.path.getmtime)
67
+
68
+ # If more than n_keys keys are found, remove the oldest
69
+ user_ids = []
70
+ if len(key_dirs) > n_keys:
71
+ n_keys_to_delete = len(key_dirs) - n_keys
72
+ for key_dir in key_dirs[:n_keys_to_delete]:
73
+ user_ids.append(key_dir.name)
74
+ shutil.rmtree(key_dir)
75
+
76
+ # Get all the encrypted objects in the temporary folder
77
+ client_files = CLIENT_FILES.iterdir()
78
+ server_files = SERVER_FILES.iterdir()
79
+
80
+ # Delete all files related to the ids whose keys were deleted
81
+ for file in chain(client_files, server_files):
82
+ for user_id in user_ids:
83
+ if user_id in file.name:
84
+ file.unlink()
85
+
86
+
87
+ def _get_client(client_id, client_type):
88
+ """Get the client API.
89
+
90
+ Args:
91
+ client_id (int): The client ID to consider.
92
+ client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
93
+
94
+ Returns:
95
+ FHEModelClient: The client API.
96
+ """
97
+ key_dir = FHE_KEYS / f"{client_type}_{client_id}"
98
+
99
+ return MultiInputsFHEModelClient(DEPLOYMENT_PATH, key_dir=key_dir, nb_inputs=len(CLIENT_TYPES))
100
+
101
+
102
+ def _keygen(client_id, client_type):
103
+ """Generate the private key associated to a filter.
104
+
105
+ Args:
106
+ client_id (int): The client ID to consider.
107
+ client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
108
+ """
109
+ # Clean temporary files
110
+ clean_temporary_files()
111
+
112
+ # Retrieve the client instance
113
+ client = _get_client(client_id, client_type)
114
+
115
+ # Generate a private key
116
+ client.generate_private_and_evaluation_keys(force=True)
117
+
118
+ # Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this
119
+ # evaluation key is empty. However, for software reasons, it is still needed for proper FHE
120
+ # execution
121
+ evaluation_key = client.get_serialized_evaluation_keys()
122
+
123
+ # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
124
+ # buttons (see https://github.com/gradio-app/gradio/issues/1877)
125
+ evaluation_key_path = _get_client_file_path("evaluation_key", client_id, client_type)
126
+
127
+ with evaluation_key_path.open("wb") as evaluation_key_file:
128
+ evaluation_key_file.write(evaluation_key)
129
+
130
+
131
+ def _send_input(client_id, client_type):
132
+ """Send the encrypted input image as well as the evaluation key to the server.
133
+
134
+ Args:
135
+ client_id (int): The client ID to consider.
136
+ client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
137
+ """
138
+ # Get the paths to the evaluation key and encrypted inputs
139
+ evaluation_key_path = _get_client_file_path("evaluation_key", client_id, client_type)
140
+ encrypted_input_path = _get_client_file_path("encrypted_inputs", client_id, client_type)
141
+
142
+ # Define the data and files to post
143
+ data = {
144
+ "client_id": client_id,
145
+ "client_type": client_type,
146
+ }
147
+
148
+ files = [
149
+ ("files", open(encrypted_input_path, "rb")),
150
+ ("files", open(evaluation_key_path, "rb")),
151
+ ]
152
+
153
+ # Send the encrypted input image and evaluation key to the server
154
+ url = SERVER_URL + "send_input"
155
+ with requests.post(
156
+ url=url,
157
+ data=data,
158
+ files=files,
159
+ ) as response:
160
+ return response.ok
161
+
162
+
163
+ def _get_client_file_path(name, client_id, client_type):
164
+ """Get the correct temporary file path for the client.
165
+
166
+ Args:
167
+ name (str): The desired file name (either 'evaluation_key' or 'encrypted_inputs').
168
+ client_id (int): The client ID to consider.
169
+ client_type (str): The type of user to consider (either 'user', 'bank' or 'third_party').
170
+
171
+ Returns:
172
+ pathlib.Path: The file path.
173
+ """
174
+ return CLIENT_FILES / f"{name}_{client_type}_{client_id}"
175
+
176
+
177
+ def _keygen_encrypt_send(inputs, client_type):
178
+ """Encrypt the given inputs for a specific client.
179
+
180
+ Args:
181
+ inputs (numpy.ndarray): The inputs to encrypt.
182
+ client_type (str): The type of client to consider (either 'user', 'bank' or 'third_party').
183
+
184
+ Returns:
185
+ client_id, encrypted_inputs_short (int, bytes): Integer ID representing the current client
186
+ and a byte short representation of the encrypted input to send.
187
+ """
188
+ # Create an ID for the current client to consider
189
+ client_id = numpy.random.randint(0, 2**32)
190
+
191
+ _keygen(client_id, client_type)
192
+
193
+ # Retrieve the client instance
194
+ client = _get_client(client_id, client_type)
195
+
196
+ # TODO : pre-process the data first
197
+
198
+ # Quantize, encrypt and serialize the inputs
199
+ encrypted_inputs = client.quantize_encrypt_serialize_multi_inputs(
200
+ inputs,
201
+ input_index=INPUT_INDEXES[client_type],
202
+ initial_input_shape=INITIAL_INPUT_SHAPE,
203
+ input_slice=INPUT_SLICES[client_type],
204
+ )
205
+
206
+ # Save encrypted_inputs to bytes in a file, since too large to pass through regular Gradio
207
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
208
+ encrypted_inputs_path = _get_client_file_path("encrypted_inputs", client_id, client_type)
209
+
210
+ with encrypted_inputs_path.open("wb") as encrypted_inputs_file:
211
+ encrypted_inputs_file.write(encrypted_inputs)
212
+
213
+ # Create a truncated version of the encrypted image for display
214
+ encrypted_inputs_short = shorten_bytes_object(encrypted_inputs)
215
+
216
+ _send_input(client_id, client_type)
217
+
218
+ # TODO: also return private key representation if possible
219
+ return client_id, encrypted_inputs_short
220
+
221
+
222
+ def pre_process_keygen_encrypt_send_user(*inputs):
223
+ """Pre-process the given inputs for a specific client.
224
+
225
+ Args:
226
+ *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
227
+
228
+ Returns:
229
+ (int, bytes): Integer ID representing the current client and a byte short representation of
230
+ the encrypted input to send.
231
+ """
232
+ gender, bool_inputs, num_children, num_family, total_income, age, income_type, education_type, \
233
+ family_status, occupation_type, housing_type = inputs
234
+
235
+ # Encoding given in https://www.kaggle.com/code/samuelcortinhas/credit-cards-data-cleaning
236
+ # for "Gender" is M ('Male') -> 1 and F ('Female') -> 0
237
+ gender = gender == "Male"
238
+
239
+ # Retrieve boolean values
240
+ own_car = "Car" in bool_inputs
241
+ own_property = "Property" in bool_inputs
242
+ work_phone = "Work phone" in bool_inputs
243
+ phone = "Phone" in bool_inputs
244
+ email = "Email" in bool_inputs
245
+
246
+ user_inputs = pandas.DataFrame({
247
+ "Gender": [gender],
248
+ "Own_car": [own_car],
249
+ "Own_property": [own_property],
250
+ "Work_phone": [work_phone],
251
+ "Phone": [phone],
252
+ "Email": [email],
253
+ "Num_children": num_children,
254
+ "Num_family": num_family,
255
+ "Total_income": total_income,
256
+ "Age": age,
257
+ "Income_type": income_type,
258
+ "Education_type": education_type,
259
+ "Family_status": family_status,
260
+ "Occupation_type": occupation_type,
261
+ "Housing_type": housing_type,
262
+ })
263
+
264
+ preprocessed_user_inputs = PRE_PROCESSOR_USER.transform(user_inputs)
265
+
266
+ return _keygen_encrypt_send(preprocessed_user_inputs, "user")
267
+
268
+
269
+ def pre_process_keygen_encrypt_send_bank(*inputs):
270
+ """Pre-process the given inputs for a specific client.
271
+
272
+ Args:
273
+ *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
274
+
275
+ Returns:
276
+ (int, bytes): Integer ID representing the current client and a byte short representation of
277
+ the encrypted input to send.
278
+ """
279
+ account_length = inputs[0]
280
+
281
+ return _keygen_encrypt_send(account_length, "bank")
282
+
283
+
284
+ def pre_process_keygen_encrypt_send_third_party(*inputs):
285
+ """Pre-process the given inputs for a specific client.
286
+
287
+ Args:
288
+ *inputs (Tuple[numpy.ndarray]): The inputs to pre-process.
289
+
290
+ Returns:
291
+ (int, bytes): Integer ID representing the current client and a byte short representation of
292
+ the encrypted input to send.
293
+ """
294
+ employed, years_employed = inputs
295
+
296
+ # Original dataset contains an "unemployed" feature instead of "employed"
297
+ unemployed = employed == "No"
298
+
299
+ third_party_inputs = pandas.DataFrame({
300
+ "Unemployed": [unemployed],
301
+ "Years_employed": [years_employed],
302
+ })
303
+
304
+ preprocessed_third_party_inputs = PRE_PROCESSOR_THIRD_PARTY.transform(third_party_inputs)
305
+
306
+ return _keygen_encrypt_send(preprocessed_third_party_inputs, "third_party")
307
+
308
+
309
+ def run_fhe(user_id, bank_id, third_party_id):
310
+ """Run the model on the encrypted inputs previously sent using FHE.
311
+
312
+ Args:
313
+ user_id (int): The user ID to consider.
314
+ bank_id (int): The bank ID to consider.
315
+ third_party_id (int): The third party ID to consider.
316
+ """
317
+
318
+ # TODO : add a warning for users to send all client types' inputs
319
+
320
+ data = {
321
+ "user_id": user_id,
322
+ "bank_id": bank_id,
323
+ "third_party_id": third_party_id,
324
+ }
325
+
326
+ # Trigger the FHE execution on the encrypted inputs previously sent
327
+ url = SERVER_URL + "run_fhe"
328
+ with requests.post(
329
+ url=url,
330
+ data=data,
331
+ ) as response:
332
+ if response.ok:
333
+ return response.json()
334
+ else:
335
+ raise gr.Error("Please wait for the inputs to be sent to the server.")
336
+
337
+
338
+ def get_output(user_id, bank_id, third_party_id):
339
+ """Retrieve the encrypted output.
340
+
341
+ Args:
342
+ user_id (int): The user ID to consider.
343
+ bank_id (int): The bank ID to consider.
344
+ third_party_id (int): The third party ID to consider.
345
+ """
346
+ data = {
347
+ "user_id": user_id,
348
+ "bank_id": bank_id,
349
+ "third_party_id": third_party_id,
350
+ }
351
+
352
+ # Retrieve the encrypted output image
353
+ url = SERVER_URL + "get_output"
354
+ with requests.post(
355
+ url=url,
356
+ data=data,
357
+ ) as response:
358
+ if response.ok:
359
+ encrypted_output = response.content
360
+
361
+ # Save the encrypted output to bytes in a file as it is too large to pass through regular
362
+ # Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
363
+ # TODO : check if output to user is relevant
364
+ encrypted_output_path = _get_client_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
365
+
366
+ with encrypted_output_path.open("wb") as encrypted_output_file:
367
+ encrypted_output_file.write(encrypted_output)
368
+
369
+ # TODO
370
+ # Decrypt the output using a different (wrong) key for display
371
+ # output_encrypted_representation = decrypt_output_with_wrong_key(encrypted_output, client_type)
372
+
373
+ # return output_encrypted_representation
374
+
375
+ return None
376
+ else:
377
+ raise gr.Error("Please wait for the FHE execution to be completed.")
378
+
379
+
380
+ def decrypt_output(user_id, bank_id, third_party_id):
381
+ """Decrypt the result.
382
+
383
+ Args:
384
+ user_id (int): The user ID to consider.
385
+ bank_id (int): The bank ID to consider.
386
+ third_party_id (int): The third party ID to consider.
387
+
388
+ Returns:
389
+ output(numpy.ndarray): The decrypted output
390
+
391
+ """
392
+ # Get the encrypted output path
393
+ encrypted_output_path = _get_client_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
394
+
395
+ if not encrypted_output_path.is_file():
396
+ raise gr.Error("Please run the FHE execution first.")
397
+
398
+ # Load the encrypted output as bytes
399
+ with encrypted_output_path.open("rb") as encrypted_output_file:
400
+ encrypted_output_proba = encrypted_output_file.read()
401
+
402
+ # Retrieve the client API
403
+ client = _get_client(user_id, "user")
404
+
405
+ # Deserialize, decrypt and post-process the encrypted output
406
+ output_proba = client.deserialize_decrypt_dequantize(encrypted_output_proba)
407
+
408
+ # Determine the predicted class
409
+ output = numpy.argmax(output_proba, axis=1)
410
+
411
+ return output
data/clean_data.csv CHANGED
The diff for this file is too large to render. See raw diff
 
deployment_files/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4b42d1dff3521c2e7462994c6eafb072bf004108d27c838e690a6702d775c0b5
3
- size 35673
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06c7bd8264089eb169342aa5c3f638b11d894c54d054511a91523bfdfab69487
3
+ size 76130
deployment_files/pre_processor_third_party.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee39c00c8ca119a4e61f6905687c9bb540352b5ce4005aaba125290679722587
3
+ size 1590
deployment_files/pre_processor_user.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af3db3f40e0e38febb8efb858e07df1f432458cc66f2edb38bedbd4d35520802
3
+ size 6207
deployment_files/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:75b15663431ff4f3788b380c100ea87c1bf97959234aeefb51ae734bed7514c4
3
- size 10953
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04c3f1de7261abe6ad075f6cc13885677ddf4ca0b03d6a31f26a60f94d5aa2ae
3
+ size 10975
development.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train and compile the model."""
2
+
3
+ import shutil
4
+ import numpy
5
+ import pandas
6
+ import pickle
7
+
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import accuracy_score
10
+ from imblearn.over_sampling import SMOTE
11
+
12
+ from settings import DEPLOYMENT_PATH, RANDOM_STATE, DATA_PATH, INPUT_SLICES, PRE_PROCESSOR_USER_PATH, PRE_PROCESSOR_THIRD_PARTY_PATH
13
+ from utils.client_server_interface import MultiInputsFHEModelDev
14
+ from utils.model import MultiInputXGBClassifier
15
+ from utils.pre_processing import get_pre_processors, select_and_pop_features
16
+
17
+
18
+ def get_processed_multi_inputs(data):
19
+ return (
20
+ data[:, INPUT_SLICES["user"]],
21
+ data[:, INPUT_SLICES["bank"]],
22
+ data[:, INPUT_SLICES["third_party"]]
23
+ )
24
+
25
+ print("Load and pre-process the data")
26
+
27
+ data = pandas.read_csv(DATA_PATH, encoding="utf-8")
28
+
29
+ # Define input and target data
30
+ data_y = data.pop("Target").copy()
31
+ data_x = data.copy()
32
+
33
+ # Get data from all parties
34
+ data_third_party = select_and_pop_features(data_x, ["Years_employed", "Unemployed"])
35
+ data_bank = select_and_pop_features(data_x, ["Account_length"])
36
+ data_user = data_x.copy()
37
+
38
+ # Feature engineer the data
39
+ pre_processor_user, pre_processor_third_party = get_pre_processors()
40
+
41
+ preprocessed_data_user = pre_processor_user.fit_transform(data_user)
42
+ preprocessed_data_bank = data_bank.to_numpy()
43
+ preprocessed_data_third_party = pre_processor_third_party.fit_transform(data_third_party)
44
+
45
+ preprocessed_data_x = numpy.concatenate((preprocessed_data_user, preprocessed_data_bank, preprocessed_data_third_party), axis=1)
46
+
47
+ # The initial data-set is very imbalanced: use SMOTE to get better results
48
+ x, y = SMOTE().fit_resample(preprocessed_data_x, data_y)
49
+
50
+ # Retrieve the training and testing data
51
+ X_train, X_test, y_train, y_test = train_test_split(
52
+ x, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
53
+ )
54
+
55
+
56
+ print("\nTrain and compile the model")
57
+
58
+ model = MultiInputXGBClassifier(max_depth=3, n_estimators=40)
59
+
60
+ model, sklearn_model = model.fit_benchmark(X_train, y_train)
61
+
62
+ multi_inputs_train = get_processed_multi_inputs(X_train)
63
+
64
+ model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
65
+
66
+ # Delete the deployment folder and its content if it already exists
67
+ if DEPLOYMENT_PATH.is_dir():
68
+ shutil.rmtree(DEPLOYMENT_PATH)
69
+
70
+
71
+ print("\nEvaluate the models")
72
+
73
+ y_pred_sklearn = sklearn_model.predict(X_test)
74
+
75
+ print(f"Sklearn accuracy score : {accuracy_score(y_test, y_pred_sklearn )*100:.2f}%")
76
+
77
+ multi_inputs_test = get_processed_multi_inputs(X_test)
78
+
79
+ y_pred_simulated = model.predict_multi_inputs(*multi_inputs_test, simulate=True)
80
+
81
+ print(f"Concrete ML accuracy score (simulated) : {accuracy_score(y_test, y_pred_simulated)*100:.2f}%")
82
+
83
+
84
+ print("\nSave deployment files")
85
+
86
+ # Save files needed for deployment
87
+ fhe_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
88
+ fhe_dev.save()
89
+
90
+ # Save pre-processors
91
+ with PRE_PROCESSOR_USER_PATH.open('wb') as file:
92
+ pickle.dump(pre_processor_user, file)
93
+
94
+ with PRE_PROCESSOR_THIRD_PARTY_PATH.open('wb') as file:
95
+ pickle.dump(pre_processor_third_party, file)
96
+
97
+ print("\nDone !")
development/development.py DELETED
@@ -1,67 +0,0 @@
1
- "A script to generate all development files necessary for the project."
2
-
3
- import shutil
4
- import numpy
5
- import pandas
6
-
7
- from sklearn.model_selection import train_test_split
8
- from imblearn.over_sampling import SMOTE
9
-
10
- from ..settings import DEPLOYMENT_PATH, RANDOM_STATE
11
- from client_server_interface import MultiInputsFHEModelDev
12
- from model import MultiInputXGBClassifier
13
- from development.pre_processing import pre_process_data
14
-
15
-
16
- print("Load and pre-process the data")
17
-
18
- data = pandas.read_csv("data/clean_data.csv", encoding="utf-8")
19
-
20
- # Make median annual salary similar to France (2023): from 157500 to 22050
21
- data["Total_income"] = data["Total_income"] * 0.14
22
-
23
- # Remove ID feature
24
- data.drop("ID", axis=1, inplace=True)
25
-
26
- # Feature engineer the data
27
- pre_processed_data, training_bins = pre_process_data(data)
28
-
29
- # Define input and target data
30
- y = pre_processed_data.pop("Target")
31
- x = pre_processed_data
32
-
33
- # The initial data-set is very imbalanced: use SMOTE to get better results
34
- x, y = SMOTE().fit_resample(x, y)
35
-
36
- # Retrieve the training data
37
- X_train, _, y_train, _ = train_test_split(
38
- x, y, stratify=y, test_size=0.3, random_state=RANDOM_STATE
39
- )
40
-
41
- # Convert the Pandas data frames into Numpy arrays
42
- X_train_np = X_train.to_numpy()
43
- y_train_np = y_train.to_numpy()
44
-
45
-
46
- print("Train and compile the model")
47
-
48
- model = MultiInputXGBClassifier(max_depth=3, n_estimators=40)
49
-
50
- model.fit(X_train_np, y_train_np)
51
-
52
- multi_inputs_train = numpy.array_split(X_train_np, 3, axis=1)
53
-
54
- model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
55
-
56
- # Delete the deployment folder and its content if it already exists
57
- if DEPLOYMENT_PATH.is_dir():
58
- shutil.rmtree(DEPLOYMENT_PATH)
59
-
60
-
61
- print("Save deployment files")
62
-
63
- # Save the files needed for deployment
64
- fhe_dev = MultiInputsFHEModelDev(model, DEPLOYMENT_PATH)
65
- fhe_dev.save()
66
-
67
- print("Done !")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
development/pre_processing.py DELETED
@@ -1,122 +0,0 @@
1
- import pandas
2
- from copy import deepcopy
3
-
4
-
5
- def convert_dummy(df, feature):
6
- pos = pandas.get_dummies(df[feature], prefix=feature)
7
-
8
- df.drop([feature], axis=1, inplace=True)
9
- df = df.join(pos)
10
- return df
11
-
12
-
13
- def get_category(df, col, labels, qcut=False, binsnum=None, bins=None, retbins=False):
14
- assert binsnum is not None or bins is not None
15
-
16
- if qcut and binsnum is not None:
17
- localdf, bin_edges = pandas.qcut(df[col], q=binsnum, labels=labels, retbins=True) # quantile cut
18
- else:
19
- input_bins = bins if bins is not None else binsnum
20
- localdf, bin_edges = pandas.cut(df[col], bins=input_bins, labels=labels, retbins=True) # equal-length cut
21
-
22
- df.drop(col, axis=1, inplace=True)
23
-
24
- localdf = pandas.DataFrame(localdf)
25
- df = df.join(localdf[col])
26
-
27
- if retbins:
28
- return df, bin_edges
29
-
30
- return df
31
-
32
-
33
- def pre_process_data(input_data, bins=None, columns=None):
34
- assert bins is None or ("bin_edges_income" in bins and "bin_edges_age" in bins and "bin_edges_years_employed" in bins and columns is not None)
35
-
36
- training_bins = {}
37
-
38
- input_data = deepcopy(input_data)
39
- bins = deepcopy(bins) if bins is not None else None
40
-
41
- input_data.loc[input_data["Num_children"] >= 2, "Num_children"] = "2_or_more"
42
-
43
- input_data = convert_dummy(input_data, "Num_children")
44
-
45
- if bins is None:
46
- input_data, bin_edges_income = get_category(input_data, "Total_income", ["low", "medium", "high"], qcut=True, binsnum=3, retbins=True)
47
- training_bins["bin_edges_income"] = bin_edges_income
48
- else:
49
- input_data = get_category(input_data, "Total_income", ["low", "medium", "high"], bins=bins["bin_edges_income"])
50
-
51
- input_data = convert_dummy(input_data, "Total_income")
52
-
53
- if bins is None:
54
- input_data, bin_edges_age = get_category(input_data, "Age", ["lowest", "low", "medium", "high", "highest"], binsnum=5, retbins=True)
55
- training_bins["bin_edges_age"] = bin_edges_age
56
- else:
57
- input_data = get_category(input_data, "Age", ["lowest", "low", "medium", "high", "highest"], bins=bins["bin_edges_age"])
58
-
59
- input_data = convert_dummy(input_data, "Age")
60
-
61
- if bins is None:
62
- input_data, bin_edges_years_employed = get_category(input_data, "Years_employed", ["lowest", "low", "medium", "high", "highest"], binsnum=5, retbins=True)
63
- training_bins["bin_edges_years_employed"] = bin_edges_years_employed
64
- else:
65
- input_data = get_category(input_data, "Years_employed", ["lowest", "low", "medium", "high", "highest"], bins=bins["bin_edges_years_employed"])
66
-
67
- input_data = convert_dummy(input_data, "Years_employed")
68
-
69
- input_data.loc[input_data["Num_family"] >= 3, "Num_family"] = "3_or_more"
70
-
71
- input_data = convert_dummy(input_data, "Num_family")
72
-
73
- input_data.loc[input_data["Income_type"] == "Pensioner", "Income_type"] = "State servant"
74
- input_data.loc[input_data["Income_type"] == "Student", "Income_type"] = "State servant"
75
-
76
- input_data = convert_dummy(input_data, "Income_type")
77
-
78
- input_data.loc[
79
- (input_data["Occupation_type"] == "Cleaning staff")
80
- | (input_data["Occupation_type"] == "Cooking staff")
81
- | (input_data["Occupation_type"] == "Drivers")
82
- | (input_data["Occupation_type"] == "Laborers")
83
- | (input_data["Occupation_type"] == "Low-skill Laborers")
84
- | (input_data["Occupation_type"] == "Security staff")
85
- | (input_data["Occupation_type"] == "Waiters/barmen staff"),
86
- "Occupation_type",
87
- ] = "Labor_work"
88
- input_data.loc[
89
- (input_data["Occupation_type"] == "Accountants")
90
- | (input_data["Occupation_type"] == "Core staff")
91
- | (input_data["Occupation_type"] == "HR staff")
92
- | (input_data["Occupation_type"] == "Medicine staff")
93
- | (input_data["Occupation_type"] == "Private service staff")
94
- | (input_data["Occupation_type"] == "Realty agents")
95
- | (input_data["Occupation_type"] == "Sales staff")
96
- | (input_data["Occupation_type"] == "Secretaries"),
97
- "Occupation_type",
98
- ] = "Office_work"
99
- input_data.loc[
100
- (input_data["Occupation_type"] == "Managers")
101
- | (input_data["Occupation_type"] == "High skill tech staff")
102
- | (input_data["Occupation_type"] == "IT staff"),
103
- "Occupation_type",
104
- ] = "High_tech_work"
105
-
106
- input_data = convert_dummy(input_data, "Occupation_type")
107
-
108
- input_data = convert_dummy(input_data, "Housing_type")
109
-
110
- input_data.loc[input_data["Education_type"] == "Academic degree", "Education_type"] = "Higher education"
111
- input_data = convert_dummy(input_data, "Education_type")
112
-
113
- input_data = convert_dummy(input_data, "Family_status")
114
-
115
- input_data = input_data.astype("int")
116
-
117
- if training_bins:
118
- return input_data, training_bins
119
-
120
- input_data = input_data.reindex(columns=columns, fill_value=0)
121
-
122
- return input_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server.py CHANGED
@@ -6,12 +6,13 @@ from fastapi import FastAPI, File, Form, UploadFile
6
  from fastapi.responses import JSONResponse, Response
7
 
8
  from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
9
- from development.client_server_interface import MultiInputsFHEModelServer
10
 
11
  # Load the server objects related to all currently available filters once and for all
12
  FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH)
13
 
14
- def get_server_file_path(name, client_id, client_type):
 
15
  """Get the correct temporary file path for the server.
16
 
17
  Args:
@@ -42,8 +43,8 @@ def send_input(
42
  ):
43
  """Send the inputs to the server."""
44
  # Retrieve the encrypted inputs and the evaluation key paths
45
- encrypted_inputs_path = get_server_file_path("encrypted_inputs", client_id, client_type)
46
- evaluation_key_path = get_server_file_path("evaluation_key", client_id, client_type)
47
 
48
  # Write the files using the above paths
49
  with encrypted_inputs_path.open("wb") as encrypted_inputs, evaluation_key_path.open(
@@ -55,23 +56,30 @@ def send_input(
55
 
56
  @app.post("/run_fhe")
57
  def run_fhe(
58
- client_id: str = Form(),
 
 
59
  ):
60
  """Execute the model on the encrypted inputs using FHE."""
61
- # Retrieve the evaluation key
62
- evaluation_key_path = get_server_file_path("evaluation_key", client_id, "user")
63
 
64
- # Get the evaluation key
65
- with evaluation_key_path.open("rb") as evaluation_key_file:
 
 
 
 
 
 
 
 
66
  evaluation_key = evaluation_key_file.read()
 
 
 
67
 
68
- # Get the encrypted inputs
69
- encrypted_inputs = []
70
- for client_type in CLIENT_TYPES:
71
- encrypted_inputs_path = get_server_file_path("encrypted_inputs", client_id, client_type)
72
- with encrypted_inputs_path.open("rb") as encrypted_inputs_file:
73
- encrypted_input = encrypted_inputs_file.read()
74
- encrypted_inputs.append(encrypted_input)
75
 
76
  # Run the FHE execution
77
  start = time.time()
@@ -79,7 +87,7 @@ def run_fhe(
79
  fhe_execution_time = round(time.time() - start, 2)
80
 
81
  # Retrieve the encrypted output path
82
- encrypted_output_path = get_server_file_path("encrypted_output", client_id, client_type)
83
 
84
  # Write the file using the above path
85
  with encrypted_output_path.open("wb") as output_file:
@@ -90,12 +98,13 @@ def run_fhe(
90
 
91
  @app.post("/get_output")
92
  def get_output(
93
- client_id: str = Form(),
94
- client_type: str = Form(),
 
95
  ):
96
  """Retrieve the encrypted output."""
97
  # Retrieve the encrypted output path
98
- encrypted_output_path = get_server_file_path("encrypted_output", client_id, client_type)
99
 
100
  # Read the file using the above path
101
  with encrypted_output_path.open("rb") as encrypted_output_file:
 
6
  from fastapi.responses import JSONResponse, Response
7
 
8
  from settings import DEPLOYMENT_PATH, SERVER_FILES, CLIENT_TYPES
9
+ from utils.client_server_interface import MultiInputsFHEModelServer
10
 
11
  # Load the server objects related to all currently available filters once and for all
12
  FHE_SERVER = MultiInputsFHEModelServer(DEPLOYMENT_PATH)
13
 
14
+
15
+ def _get_server_file_path(name, client_id, client_type):
16
  """Get the correct temporary file path for the server.
17
 
18
  Args:
 
43
  ):
44
  """Send the inputs to the server."""
45
  # Retrieve the encrypted inputs and the evaluation key paths
46
+ encrypted_inputs_path = _get_server_file_path("encrypted_inputs", client_id, client_type)
47
+ evaluation_key_path = _get_server_file_path("evaluation_key", client_id, client_type)
48
 
49
  # Write the files using the above paths
50
  with encrypted_inputs_path.open("wb") as encrypted_inputs, evaluation_key_path.open(
 
56
 
57
  @app.post("/run_fhe")
58
  def run_fhe(
59
+ user_id: str = Form(),
60
+ bank_id: str = Form(),
61
+ third_party_id: str = Form(),
62
  ):
63
  """Execute the model on the encrypted inputs using FHE."""
64
+ # Retrieve the evaluation key (from the user, as all evaluation keys should be the same)
65
+ evaluation_key_path = _get_server_file_path("evaluation_key", user_id, "user")
66
 
67
+ # Get the encrypted inputs
68
+ encrypted_user_inputs_path = _get_server_file_path("encrypted_inputs", user_id, "user")
69
+ encrypted_bank_inputs_path = _get_server_file_path("encrypted_inputs", bank_id, "bank")
70
+ encrypted_third_party_inputs_path = _get_server_file_path("encrypted_inputs", third_party_id, "third_party")
71
+ with (
72
+ evaluation_key_path.open("rb") as evaluation_key_file,
73
+ encrypted_user_inputs_path.open("rb") as encrypted_user_inputs_file,
74
+ encrypted_bank_inputs_path.open("rb") as encrypted_bank_inputs_file,
75
+ encrypted_third_party_inputs_path.open("rb") as encrypted_third_party_inputs_file,
76
+ ):
77
  evaluation_key = evaluation_key_file.read()
78
+ encrypted_user_input = encrypted_user_inputs_file.read()
79
+ encrypted_bank_input = encrypted_bank_inputs_file.read()
80
+ encrypted_third_party_input = encrypted_third_party_inputs_file.read()
81
 
82
+ encrypted_inputs = (encrypted_user_input, encrypted_bank_input, encrypted_third_party_input)
 
 
 
 
 
 
83
 
84
  # Run the FHE execution
85
  start = time.time()
 
87
  fhe_execution_time = round(time.time() - start, 2)
88
 
89
  # Retrieve the encrypted output path
90
+ encrypted_output_path = _get_server_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
91
 
92
  # Write the file using the above path
93
  with encrypted_output_path.open("wb") as output_file:
 
98
 
99
  @app.post("/get_output")
100
  def get_output(
101
+ user_id: str = Form(),
102
+ bank_id: str = Form(),
103
+ third_party_id: str = Form(),
104
  ):
105
  """Retrieve the encrypted output."""
106
  # Retrieve the encrypted output path
107
+ encrypted_output_path = _get_server_file_path("encrypted_output", user_id + bank_id + third_party_id, "output")
108
 
109
  # Read the file using the above path
110
  with encrypted_output_path.open("rb") as encrypted_output_file:
settings.py CHANGED
@@ -1,6 +1,7 @@
1
  "All constants used in the project."
2
 
3
  from pathlib import Path
 
4
 
5
  # The directory of this project
6
  REPO_DIR = Path(__file__).parent
@@ -11,6 +12,10 @@ FHE_KEYS = REPO_DIR / ".fhe_keys"
11
  CLIENT_FILES = REPO_DIR / "client_files"
12
  SERVER_FILES = REPO_DIR / "server_files"
13
 
 
 
 
 
14
  # Create the necessary directories
15
  FHE_KEYS.mkdir(exist_ok=True)
16
  CLIENT_FILES.mkdir(exist_ok=True)
@@ -19,8 +24,14 @@ SERVER_FILES.mkdir(exist_ok=True)
19
  # Store the server's URL
20
  SERVER_URL = "http://localhost:8000/"
21
 
22
- RANDOM_STATE = 0
 
 
 
 
23
 
 
 
24
  INITIAL_INPUT_SHAPE = (1, 49)
25
 
26
  CLIENT_TYPES = ["user", "bank", "third_party"]
@@ -29,8 +40,33 @@ INPUT_INDEXES = {
29
  "bank": 1,
30
  "third_party": 2,
31
  }
32
- START_POSITIONS = {
33
- "user": 0, # First position: start from 0
34
- "bank": 17, # Second position: start from len(input_user)
35
- "third_party": 33, # Third position: start from len(input_user) + len(input_bank)
36
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  "All constants used in the project."
2
 
3
  from pathlib import Path
4
+ import pandas
5
 
6
  # The directory of this project
7
  REPO_DIR = Path(__file__).parent
 
12
  CLIENT_FILES = REPO_DIR / "client_files"
13
  SERVER_FILES = REPO_DIR / "server_files"
14
 
15
+ # Path targeting pre-processor saved files
16
+ PRE_PROCESSOR_USER_PATH = DEPLOYMENT_PATH / 'pre_processor_user.pkl'
17
+ PRE_PROCESSOR_THIRD_PARTY_PATH = DEPLOYMENT_PATH / 'pre_processor_third_party.pkl'
18
+
19
  # Create the necessary directories
20
  FHE_KEYS.mkdir(exist_ok=True)
21
  CLIENT_FILES.mkdir(exist_ok=True)
 
24
  # Store the server's URL
25
  SERVER_URL = "http://localhost:8000/"
26
 
27
+ # Path to data file
28
+ # The data was previously cleaned using this notebook : https://www.kaggle.com/code/samuelcortinhas/credit-cards-data-cleaning
29
+ # Additionally, the "ID" columns has been removed and the "Total_income" has been adjusted so that
30
+ # its median value corresponds to France's 2023 median annual salary (22050 euros)
31
+ DATA_PATH = "data/clean_data.csv"
32
 
33
+ # Developement settings
34
+ RANDOM_STATE = 0
35
  INITIAL_INPUT_SHAPE = (1, 49)
36
 
37
  CLIENT_TYPES = ["user", "bank", "third_party"]
 
40
  "bank": 1,
41
  "third_party": 2,
42
  }
43
+ INPUT_SLICES = {
44
+ "user": slice(0, 42), # First position: start from 0
45
+ "bank": slice(42, 43), # Second position: start from n_feature_user
46
+ "third_party": slice(43, 49), # Third position: start from n_feature_user + n_feature_bank
47
  }
48
+
49
+ _data = pandas.read_csv(DATA_PATH, encoding="utf-8")
50
+
51
+ def get_min_max(data, column):
52
+ """Get min/max values of a column in order to input them in Gradio's API as key arguments."""
53
+ return {
54
+ "minimum": int(data[column].min()),
55
+ "maximum": int(data[column].max()),
56
+ }
57
+
58
+ # App data min and max values
59
+ ACCOUNT_MIN_MAX = get_min_max(_data, "Account_length")
60
+ CHILDREN_MIN_MAX = get_min_max(_data, "Num_children")
61
+ INCOME_MIN_MAX = get_min_max(_data, "Total_income")
62
+ AGE_MIN_MAX = get_min_max(_data, "Age")
63
+ EMPLOYED_MIN_MAX = get_min_max(_data, "Years_employed")
64
+ FAMILY_MIN_MAX = get_min_max(_data, "Num_family")
65
+
66
+ # App data choices
67
+ INCOME_TYPES = list(_data["Income_type"].unique())
68
+ OCCUPATION_TYPES = list(_data["Occupation_type"].unique())
69
+ HOUSING_TYPES = list(_data["Housing_type"].unique())
70
+ EDUCATION_TYPES = list(_data["Education_type"].unique())
71
+ FAMILY_STATUS = list(_data["Family_status"].unique())
72
+
{development β†’ utils}/client_server_interface.py RENAMED
@@ -1,3 +1,5 @@
 
 
1
  import numpy
2
  import copy
3
 
@@ -25,22 +27,21 @@ class MultiInputsFHEModelClient(FHEModelClient):
25
 
26
  super().__init__(*args, **kwargs)
27
 
28
- def quantize_encrypt_serialize_multi_inputs(self, x: numpy.ndarray, input_index, initial_input_shape, start_position) -> bytes:
29
 
30
  x_padded = numpy.zeros(initial_input_shape)
31
 
32
- end = start_position + x.shape[1]
33
- x_padded[:, start_position:end] = x
34
 
35
  q_x_padded = self.model.quantize_input(x_padded)
36
 
37
- q_x = q_x_padded[:, start_position:end]
38
 
39
- q_x_padded = [None for _ in range(self.nb_inputs)]
40
- q_x_padded[input_index] = q_x
41
 
42
  # Encrypt the values
43
- q_x_enc = self.client.encrypt(*q_x_padded)
44
 
45
  # Serialize the encrypted values to be sent to the server
46
  q_x_enc_ser = q_x_enc[input_index].serialize()
 
1
+ """Modified classes for use for Client-Server interface with multi-inputs circuits."""
2
+
3
  import numpy
4
  import copy
5
 
 
27
 
28
  super().__init__(*args, **kwargs)
29
 
30
+ def quantize_encrypt_serialize_multi_inputs(self, x: numpy.ndarray, input_index, initial_input_shape, input_slice) -> bytes:
31
 
32
  x_padded = numpy.zeros(initial_input_shape)
33
 
34
+ x_padded[:, input_slice] = x
 
35
 
36
  q_x_padded = self.model.quantize_input(x_padded)
37
 
38
+ q_x = q_x_padded[:, input_slice]
39
 
40
+ q_x_inputs = [None for _ in range(self.nb_inputs)]
41
+ q_x_inputs[input_index] = q_x
42
 
43
  # Encrypt the values
44
+ q_x_enc = self.client.encrypt(*q_x_inputs)
45
 
46
  # Serialize the encrypted values to be sent to the server
47
  q_x_enc_ser = q_x_enc[input_index].serialize()
{development β†’ utils}/model.py RENAMED
@@ -1,4 +1,7 @@
 
 
1
  import numpy
 
2
  from typing import Optional, Sequence, Union
3
 
4
  from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
@@ -128,3 +131,43 @@ class MultiInputXGBClassifier(ConcreteXGBClassifier):
128
  )
129
 
130
  return compiler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified model class to handles multi-inputs circuit."""
2
+
3
  import numpy
4
+ import time
5
  from typing import Optional, Sequence, Union
6
 
7
  from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit
 
131
  )
132
 
133
  return compiler
134
+
135
+ def predict_multi_inputs(self, *multi_inputs, simulate=True):
136
+ """Run the inference with multiple inputs, with simulation or in FHE."""
137
+ assert all(isinstance(inputs, numpy.ndarray) for inputs in multi_inputs)
138
+
139
+ if not simulate:
140
+ self.fhe_circuit.keygen()
141
+
142
+ y_preds = []
143
+ execution_times = []
144
+ for inputs in zip(*multi_inputs):
145
+ inputs = tuple(numpy.expand_dims(input, axis=0) for input in inputs)
146
+
147
+ q_inputs = self.quantize_input(*inputs)
148
+
149
+ if simulate:
150
+ q_y_proba = self.fhe_circuit.simulate(*q_inputs)
151
+ else:
152
+ q_inputs_enc = self.fhe_circuit.encrypt(*q_inputs)
153
+
154
+ start = time.time()
155
+ q_y_proba_enc = self.fhe_circuit.run(*q_inputs_enc)
156
+ end = time.time() - start
157
+
158
+ execution_times.append(end)
159
+
160
+ q_y_proba = self.fhe_circuit.decrypt(q_y_proba_enc)
161
+
162
+ y_proba = self.dequantize_output(q_y_proba)
163
+
164
+ y_proba = self.post_processing(y_proba)
165
+
166
+ y_pred = numpy.argmax(y_proba, axis=1)
167
+
168
+ y_preds.append(y_pred)
169
+
170
+ if not simulate:
171
+ print(f"FHE execution time per inference: {numpy.mean(execution_times) :.2}s")
172
+
173
+ return numpy.array(y_preds)
utils/pre_processing.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data pre-processing functions."""
2
+
3
+ import numpy
4
+ from sklearn.compose import ColumnTransformer
5
+ from sklearn.pipeline import Pipeline
6
+ from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, KBinsDiscretizer
7
+
8
+
9
+ def _get_pipeline_replace_one_hot(func, value):
10
+ return Pipeline([
11
+ ("replace", FunctionTransformer(
12
+ func,
13
+ kw_args={"value": value},
14
+ feature_names_out='one-to-one',
15
+ )),
16
+ ("one_hot", OneHotEncoder(),),
17
+ ])
18
+
19
+
20
+ def _replace_values_geq(column, value):
21
+ return numpy.where(column >= value, f"{value}_or_more", column)
22
+
23
+ def _replace_values_eq(column, value):
24
+ for desired_value, values_to_replace in value.items():
25
+ column = numpy.where(numpy.isin(column, values_to_replace), desired_value, column)
26
+ return column
27
+
28
+ def get_pre_processors():
29
+ pre_processor_user = ColumnTransformer(
30
+ transformers=[
31
+ (
32
+ "replace_num_children",
33
+ _get_pipeline_replace_one_hot(_replace_values_geq, 2),
34
+ ['Num_children']
35
+ ),
36
+ (
37
+ "replace_num_family",
38
+ _get_pipeline_replace_one_hot(_replace_values_geq, 3),
39
+ ['Num_family']
40
+ ),
41
+ (
42
+ "replace_income_type",
43
+ _get_pipeline_replace_one_hot(_replace_values_eq, {"State servant": ["Pensioner", "Student"]}),
44
+ ['Income_type']
45
+ ),
46
+ (
47
+ "replace_education_type",
48
+ _get_pipeline_replace_one_hot(_replace_values_eq, {"Higher education": ["Academic degree"]}),
49
+ ['Education_type']
50
+ ),
51
+ (
52
+ "replace_occupation_type_labor",
53
+ _get_pipeline_replace_one_hot(
54
+ _replace_values_eq,
55
+ {
56
+ "Labor_work": ["Cleaning staff", "Cooking staff", "Drivers", "Laborers", "Low-skill Laborers", "Security staff", "Waiters/barmen staff"],
57
+ "Office_work": ["Accountants", "Core staff", "HR staff", "Medicine staff", "Private service staff", "Realty agents", "Sales staff", "Secretaries"],
58
+ "High_tech_work": ["Managers", "High skill tech staff", "IT staff"],
59
+ },
60
+ ),
61
+ ['Occupation_type']
62
+ ),
63
+ ('one_hot_housing_fam_status', OneHotEncoder(), ['Housing_type', 'Family_status']),
64
+ ('qbin_total_income', KBinsDiscretizer(n_bins=3, strategy='quantile', encode="onehot"), ['Total_income']),
65
+ ('bin_age', KBinsDiscretizer(n_bins=5, strategy='uniform', encode="onehot"), ['Age']),
66
+ ],
67
+ remainder='passthrough',
68
+ verbose_feature_names_out=False,
69
+ )
70
+
71
+ pre_processor_third_party = ColumnTransformer(
72
+ transformers=[
73
+ ('bin_years_employed', KBinsDiscretizer(n_bins=5, strategy='uniform', encode="onehot"), ['Years_employed'])
74
+ ],
75
+ remainder='passthrough',
76
+ verbose_feature_names_out=False,
77
+ )
78
+
79
+ return pre_processor_user, pre_processor_third_party
80
+
81
+
82
+ def select_and_pop_features(data, columns):
83
+ new_data = data[columns].copy()
84
+ data.drop(columns, axis=1, inplace=True)
85
+ return new_data