kcelia commited on
Commit
13fb76e
1 Parent(s): 4b6bfbb

chore: update

Browse files
ConcreteXGBoostClassifier.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e95ed531c5b4fd4d23330dee9a72884979d93c4be55c1cc25e06efe027253ce4
3
+ size 599833
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle as pkl
2
+ import shutil
3
+ from pathlib import Path
4
+ from time import time
5
+ from typing import List, Tuple, Union
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn import metrics, preprocessing
11
+ from sklearn.ensemble import RandomForestClassifier as SklearnRandomForestClassifier
12
+ from sklearn.model_selection import train_test_split
13
+
14
+ from concrete.ml.common.serialization.loaders import load, loads
15
+ from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
16
+ from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
17
+
18
+ path_to_model = Path("./client_folder").resolve()
19
+
20
+ import subprocess
21
+
22
+ from preprocessing import ( # pylint: disable=wrong-import-position, no-name-in-module
23
+ map_prediction,
24
+ pretty_print,
25
+ )
26
+ from symptoms_categories import SYMPTOMS_LIST
27
+
28
+ ENCRYPTED_DATA_BROWSER_LIMIT = 500
29
+ # This repository's directory
30
+ REPO_DIR = Path(__file__).parent
31
+
32
+ print(f"{REPO_DIR=}")
33
+ # subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
34
+ # time.sleep(3)
35
+
36
+
37
+ def load_data():
38
+ # Load data
39
+ df_train = pd.read_csv("./data/Training_preprocessed.csv")
40
+ df_test = pd.read_csv("./data/Testing_preprocessed.csv")
41
+
42
+ # Separate the traget from the training set
43
+ # df['prognosis] contains the name of the disease
44
+ # df['y] contains the numeric label of the disease
45
+
46
+ y_train = df_train["y"]
47
+ X_train = df_train.drop(columns=["y", "prognosis"], axis=1, errors="ignore")
48
+
49
+ y_test = df_train["y"]
50
+ X_test = df_test.drop(columns=["y", "prognosis"], axis=1, errors="ignore")
51
+
52
+ return (df_train, X_train, X_test), (df_test, y_train, y_test)
53
+
54
+
55
+ def load_model(X_train, y_train):
56
+ concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
57
+ classifier = ConcreteXGBoostClassifier(**concrete_args)
58
+ classifier.fit(X_train, y_train)
59
+ circuit = classifier.compile(X_train)
60
+
61
+ return classifier, circuit
62
+
63
+
64
+ def key_gen():
65
+
66
+ # Key serialization
67
+ user_id = np.random.randint(0, 2**32)
68
+
69
+ client = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
70
+ client.load()
71
+
72
+ # The client first need to create the private and evaluation keys.
73
+
74
+ client.generate_private_and_evaluation_keys()
75
+
76
+ # Get the serialized evaluation keys
77
+ serialized_evaluation_keys = client.get_serialized_evaluation_keys()
78
+ assert isinstance(serialized_evaluation_keys, bytes)
79
+
80
+ np.save(f".fhe_keys/{user_id}/eval_key.npy", serialized_evaluation_keys)
81
+
82
+ serialized_evaluation_keys_shorten = list(serialized_evaluation_keys)[:200]
83
+ serialized_evaluation_keys_shorten_hex = "".join(
84
+ f"{i:02x}" for i in serialized_evaluation_keys_shorten
85
+ )
86
+ # Evaluation keys can be quite large files but only have to be shared once with the server.
87
+
88
+ # Check the size of the evaluation keys (in MB)
89
+ return [
90
+ serialized_evaluation_keys_shorten_hex,
91
+ user_id,
92
+ f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
93
+ ]
94
+
95
+
96
+ def encode_quantize_encrypt(user_symptoms, user_id):
97
+ # check if the key has been generated
98
+ client = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
99
+ client.load()
100
+
101
+ user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
102
+
103
+ quant_user_symptoms = client.model.quantize_input(user_symptoms)
104
+ encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
105
+
106
+ # print(client.model.predict(vect_x, fhe="simulate"), client.model.predict(vect_x, fhe="execute"))
107
+ # pred_s = client.model.fhe_circuit.simulate(quant_vect)
108
+ # pred_fhe = client.model.fhe_circuit.encrypt_run_decrypt(quant_vect) #
109
+ # non alpha -> \X1124, base64 ou en exa
110
+
111
+ # Compute size
112
+
113
+ np.save(f".fhe_keys/{user_id}/encrypted_quant_vect.npy", encrypted_quantized_user_symptoms)
114
+
115
+ encrypted_quantized_encoding_shorten = list(encrypted_quantized_user_symptoms)[:200]
116
+ encrypted_quantized_encoding_shorten_hex = "".join(
117
+ f"{i:02x}" for i in encrypted_quantized_encoding_shorten
118
+ )
119
+
120
+ return user_symptoms, quant_user_symptoms, encrypted_quantized_encoding_shorten_hex
121
+
122
+
123
+ def decrypt_prediction(encrypted_quantized_vect, user_id):
124
+ fhe_api = FHEModelClient(path_dir=path_to_model, key_dir=f".fhe_keys/{user_id}")
125
+ fhe_api.load()
126
+ fhe_api.generate_private_and_evaluation_keys(force=False)
127
+ predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_quantized_vect)
128
+ return predictions
129
+
130
+
131
+ def get_user_vect_symptoms_from_checkboxgroup(*user_symptoms) -> np.array:
132
+ symptoms_vector = {key: 0 for key in valid_columns}
133
+
134
+ for symptom_box in user_symptoms:
135
+ for pretty_symptom in symptom_box:
136
+ symptom = "_".join((pretty_symptom.lower().split(" ")))
137
+ if symptom not in symptoms_vector.keys():
138
+ raise KeyError(
139
+ f"The symptom '{symptom}' you provided is not recognized as a valid "
140
+ f"symptom.\nHere is the list of valid symptoms: {symptoms_vector}"
141
+ )
142
+ symptoms_vector[symptom] = 1.0
143
+
144
+ user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :]
145
+
146
+ assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten())
147
+
148
+ return user_symptoms_vect
149
+
150
+
151
+ def get_user_vect_symptoms_from_default_disease(disease):
152
+
153
+ user_symptom_vector = df_test[df_test["prognosis"] == disease].iloc[0].values
154
+
155
+ user_symptoms_vect = np.fromiter(user_symptom_vector[:-2], dtype=float)[np.newaxis, :]
156
+
157
+ assert all(value == 0 or value == 1 for value in user_symptoms_vect.flatten())
158
+
159
+ return user_symptoms_vect
160
+
161
+
162
+ def get_user_symptoms_from_default_disease(disease):
163
+ df_filtred = df_test[df_test["prognosis"] == disease]
164
+ columns_with_1 = df_filtred.columns[df_filtred.eq(1).any()].to_list()
165
+ return pretty_print(columns_with_1)
166
+
167
+
168
+ def get_user_symptoms_vector(selected_default_disease, *selected_symptoms):
169
+
170
+ if any(lst for lst in selected_symptoms if lst) and (
171
+ selected_default_disease is not None and len(selected_default_disease) > 0
172
+ ):
173
+ # If the user has already selected a disease and added more symptoms, raise an error
174
+ if set(pretty_print(selected_symptoms)) - set(
175
+ get_user_symptoms_from_default_disease(selected_default_disease)
176
+ ):
177
+ return {
178
+ user_vector_textbox: gr.update(value="An error occurs"),
179
+ error_box: gr.update(
180
+ visible=True, value="Enter a default disease or select your own symptoms"
181
+ ),
182
+ }
183
+ # If the user has not selected a default disease or symptoms, an error is raised.
184
+ if not any(lst for lst in selected_symptoms if lst) and (
185
+ selected_default_disease is None
186
+ or (selected_default_disease is not None and len(selected_default_disease) < 1)
187
+ ):
188
+ return {
189
+ user_vector_textbox: gr.update(value="An error occurs"),
190
+ error_box: gr.update(
191
+ visible=True, value="Enter a default disease or select your own symptoms"
192
+ ),
193
+ }
194
+ # Case 1: The user has checked his own symptoms
195
+ if any(lst for lst in selected_symptoms if lst):
196
+ return {
197
+ user_vector_textbox: get_user_vect_symptoms_from_checkboxgroup(*selected_symptoms),
198
+ }
199
+
200
+ # Case 2: The user has selected a default disease
201
+ if selected_default_disease is not None and len(selected_default_disease) > 0:
202
+ return {
203
+ user_vector_textbox: get_user_vect_symptoms_from_default_disease(
204
+ selected_default_disease
205
+ ),
206
+ error_box: gr.update(visible=False),
207
+ **{
208
+ box: get_user_symptoms_from_default_disease(selected_default_disease)
209
+ for box in check_boxes
210
+ },
211
+ }
212
+
213
+
214
+ def clear_all_buttons():
215
+ return {
216
+ user_id_textbox: None,
217
+ eval_key_textbox: None,
218
+ eval_key_len_textbox: None,
219
+ user_vector_textbox: None,
220
+ box_default: None,
221
+ error_box: gr.update(visible=False),
222
+ **{box: None for box in check_boxes},
223
+ }
224
+
225
+
226
+ if __name__ == "__main__":
227
+ print("Starting demo ...")
228
+
229
+ (df_train, X_train, X_test), (df_test, y_train, y_test) = load_data()
230
+
231
+ valid_columns = X_train.columns.to_list()
232
+
233
+ with gr.Blocks() as demo:
234
+
235
+ # Link + images
236
+ gr.Markdown(
237
+ """
238
+ <p align="center">
239
+ <img width=200 src="https://user-images.githubusercontent.com/5758427/197816413-d9cddad3-ba38-4793-847d-120975e1da11.png">
240
+ </p>
241
+
242
+ <h2 align="center">Health Prediction On Encrypted Data Using Homomorphic Encryption.</h2>
243
+
244
+ <p align="center">
245
+ <a href="https://github.com/zama-ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197972109-faaaff3e-10e2-4ab6-80f5-7531f7cfb08f.png">Concrete-ML</a>
246
+
247
+ <a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197976802-fddd34c5-f59a-48d0-9bff-7ad1b00cb1fb.png">Documentation</a>
248
+
249
+ <a href="https://zama.ai/community"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197977153-8c9c01a7-451a-4993-8e10-5a6ed5343d02.png">Community</a>
250
+
251
+ <a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197975044-bab9d199-e120-433b-b3be-abd73b211a54.png">@zama_fhe</a>
252
+ </p>
253
+
254
+ <p align="center">
255
+ <img src="https://raw.githubusercontent.com/kcelia/Img/main/demo-img2.png" width="60%" height="60%">
256
+ </p>
257
+ """
258
+ )
259
+
260
+ # Gentle introduction
261
+ gr.Markdown("## Introduction")
262
+ gr.Markdown("""Blablabla""")
263
+
264
+ # User symptoms
265
+ gr.Markdown("# Step 1: Provide your symptoms")
266
+ gr.Markdown("Client side")
267
+
268
+ # Default disease, picked from the dataframe
269
+ with gr.Row():
270
+ default_diseases = list(set(df_test["prognosis"]))
271
+ box_default = gr.Dropdown(default_diseases, label="Disease")
272
+
273
+ # Box symptoms
274
+ check_boxes = []
275
+ for i, category in enumerate(SYMPTOMS_LIST):
276
+ check_box = gr.CheckboxGroup(
277
+ pretty_print(category.values()),
278
+ label=pretty_print(category.keys()),
279
+ info=f"Symptoms related to `{pretty_print(category.values())}`",
280
+ max_batch_size=45,
281
+ )
282
+ check_boxes.append(check_box)
283
+
284
+ # User symptom vector
285
+ with gr.Row():
286
+ user_vector_textbox = gr.Textbox(
287
+ label="User symptoms (vector)",
288
+ interactive=False,
289
+ max_lines=100,
290
+ )
291
+ error_box = gr.Textbox(label="Error", visible=False)
292
+
293
+ with gr.Row():
294
+ # Submit botton
295
+ with gr.Column():
296
+ submit_button = gr.Button("Submit")
297
+ # Clear botton
298
+ with gr.Column():
299
+ clear_button = gr.Button("Clear", style="background-color: yellow;")
300
+
301
+ # Click submit botton
302
+
303
+ submit_button.click(
304
+ fn=get_user_symptoms_vector,
305
+ inputs=[box_default, *check_boxes],
306
+ outputs=[user_vector_textbox, error_box, *check_boxes],
307
+ )
308
+ # Load the model
309
+ concrete_classifier = load(
310
+ open("ConcreteRandomForestClassifier.pkl", "r", encoding="utf-8")
311
+ )
312
+
313
+ gr.Markdown("# Step 2: Generate the keys")
314
+ gr.Markdown("Client side")
315
+
316
+ gen_key = gr.Button("Generate the keys and send public part to server")
317
+
318
+ with gr.Row():
319
+ # User ID
320
+ with gr.Column(scale=1, min_width=600):
321
+ user_id_textbox = gr.Textbox(
322
+ label="User ID:",
323
+ max_lines=4,
324
+ interactive=False,
325
+ )
326
+ # Evaluation key size
327
+ with gr.Column(scale=1, min_width=600):
328
+ eval_key_len_textbox = gr.Textbox(
329
+ label="Evaluation key size:", max_lines=4, interactive=False
330
+ )
331
+
332
+ with gr.Row():
333
+ # Evaluation key (truncated)
334
+ with gr.Column(scale=2, min_width=600):
335
+ eval_key_textbox = gr.Textbox(
336
+ label="Evaluation key (truncated):",
337
+ max_lines=4,
338
+ interactive=False,
339
+ )
340
+
341
+ gen_key.click(key_gen, outputs=[eval_key_textbox, user_id_textbox, eval_key_len_textbox])
342
+
343
+ clear_button.click(
344
+ clear_all_buttons,
345
+ outputs=[
346
+ user_id_textbox,
347
+ user_vector_textbox,
348
+ eval_key_textbox,
349
+ eval_key_len_textbox,
350
+ box_default,
351
+ error_box,
352
+ *check_boxes,
353
+ ],
354
+ )
355
+
356
+ gr.Markdown("# Step 3: Encode the message with the private key")
357
+ gr.Markdown("Client side")
358
+
359
+ encode_msg = gr.Button("Generate the keys and send public part to server")
360
+
361
+ with gr.Row():
362
+
363
+ with gr.Column(scale=1, min_width=600):
364
+ vect_textbox = gr.Textbox(
365
+ label="Vector:",
366
+ max_lines=4,
367
+ interactive=False,
368
+ )
369
+
370
+ with gr.Column(scale=1, min_width=600):
371
+ quant_vect_textbox = gr.Textbox(
372
+ label="Quant vector:", max_lines=4, interactive=False
373
+ )
374
+
375
+ with gr.Column(scale=1, min_width=600):
376
+ encrypted_vect_textbox = gr.Textbox(
377
+ label="Encrypted vector:", max_lines=4, interactive=False
378
+ )
379
+
380
+ encode_msg.click(
381
+ encode_quantize_encrypt,
382
+ inputs=[user_vector_textbox, user_id_textbox],
383
+ outputs=[vect_textbox, quant_vect_textbox, encrypted_vect_textbox],
384
+ )
385
+
386
+ gr.Markdown("# Step 4: Run the FHE evaluation")
387
+ gr.Markdown("Server side")
388
+
389
+ run_fhe = gr.Button("Run the FHE evaluation")
390
+
391
+ gr.Markdown("# Step 5: Decrypt the sentiment")
392
+ gr.Markdown("Server side")
393
+
394
+ decrypt_target_botton = gr.Button("Decrypt the sentiment")
395
+ decrypt_target_textbox = gr.Textbox(
396
+ label="Encrypted vector:", max_lines=4, interactive=False
397
+ )
398
+
399
+ decrypt_target_botton.click(
400
+ decrypt_prediction,
401
+ inputs=[encrypted_vect_textbox, user_id_textbox],
402
+ outputs=[decrypt_target_textbox],
403
+ )
404
+
405
+ demo.launch()
client_folder/client.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d74f69c8847ee0c4d1d1828eea2d81ae0e9f20de866bb8536d391541d68c8f04
3
+ size 89862
client_folder/server.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3240edb4a0f896e56a7a077bf7ebc83a23003c96f96c5096cc80898152053f5b
3
+ size 1778
client_folder/versions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"concrete-ml": "1.0.0rc2", "concrete-python": "1.0.0", "python": "3.10.6"}
preprocessing.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Preliminary preprocessing on the data, such as:
3
+ - correcting column names
4
+ - encoding the target column
5
+ """
6
+
7
+ import pandas as pd
8
+ from sklearn import preprocessing
9
+
10
+ COLUMNS_TO_DROP = ["Unnamed: 133"]
11
+ TARGET_COLUMN = ["prognosis"]
12
+ RENAME_COLUMNS = {
13
+ "scurring": "scurving",
14
+ "dischromic _patches": "dischromic_patches",
15
+ "spotting_ urination": "spotting_urination",
16
+ "foul_smell_of urine": "foul_smell_of_urine",
17
+ }
18
+
19
+
20
+ def pretty_print(input):
21
+ """
22
+ Prettify the input.
23
+
24
+ Args:
25
+ input: Can be a list of symtoms or a disease.
26
+
27
+ Returns:
28
+ list: Sorted and prettified input.
29
+ """
30
+ # Convert to a list if necessary
31
+ if isinstance(input, list):
32
+ input = list(input)
33
+
34
+ # Flatten the list if required
35
+ pretty_list = []
36
+ for item in input:
37
+ if isinstance(item, list):
38
+ pretty_list.extend(item)
39
+ else:
40
+ pretty_list.append(item)
41
+
42
+ # Sort and prettify the input
43
+ pretty_list = sorted([" ".join((item.split("_"))).title() for item in pretty_list])
44
+
45
+ return pretty_list
46
+
47
+
48
+ def map_prediction(target_columns=["y", "prognosis"]):
49
+ df = pd.read_csv("Training_preprocessed.csv")
50
+ relevent_df = df[target_columns].drop_duplicates().relevent_df.where(df["y"] == 1)
51
+ prediction = relevent_df[target_columns[1]].dropna().values[0]
52
+ return prediction
53
+
54
+
55
+ if __name__ == "__main__":
56
+
57
+ # Load data
58
+ df_train = pd.read_csv("Training.csv")
59
+ df_test = pd.read_csv("Testing.csv")
60
+
61
+ # Remove unseless columns
62
+ df_train.drop(columns=COLUMNS_TO_DROP, axis=1, errors="ignore", inplace=True)
63
+ df_test.drop(columns=COLUMNS_TO_DROP, axis=1, errors="ignore", inplace=True)
64
+
65
+ # Correct some typos in some columns name
66
+ df_train.rename(columns=RENAME_COLUMNS, inplace=True)
67
+ df_test.rename(columns=RENAME_COLUMNS, inplace=True)
68
+
69
+ # Convert y category labels to y
70
+ label_encoder = preprocessing.LabelEncoder()
71
+ label_encoder.fit(df_train[TARGET_COLUMN].values.flatten())
72
+
73
+ df_train["y"] = label_encoder.transform(df_train[TARGET_COLUMN].values.flatten())
74
+ df_test["y"] = label_encoder.transform(df_test[TARGET_COLUMN].values.flatten())
75
+
76
+ # Cast X features from int64 to float32
77
+ float_columns = df_train.columns.drop(TARGET_COLUMN)
78
+ df_train[float_columns] = df_train[float_columns].astype("float32")
79
+ df_test[float_columns] = df_test[float_columns].astype("float32")
80
+
81
+ # Save preprocessed data
82
+ df_train.to_csv(path_or_buf="Training_preprocessed.csv", index=False)
83
+ df_test.to_csv(path_or_buf="Testing_preprocessed.csv", index=False)
server.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Server that will listen for GET and POST requests from the client."""
2
+
3
+ import time
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ from fastapi import FastAPI, File, Form, UploadFile
8
+ from fastapi.responses import JSONResponse, Response
9
+
10
+ from concrete.ml.deployment import FHEModelServer
11
+
12
+ # Initialize an instance of FastAPI
13
+ app = FastAPI()
14
+
15
+ current_dir = Path(__file__).parent
16
+
17
+ # Load the model
18
+ fhe_model = FHEModelServer(Path.joinpath(current_dir, "./client_folder"))
19
+
20
+ # Define the default route
21
+ @app.get("/")
22
+ def root():
23
+ return {"message": "Welcome to Your disease prediction with fhe !"}
24
+
25
+
26
+ @app.post("/send_input")
27
+ def send_input(
28
+ user_id: str = Form(),
29
+ filter: str = Form(),
30
+ files: List[UploadFile] = File(),
31
+ ):
32
+ """Send the inputs to the server."""
33
+ # Retrieve the encrypted input image and the evaluation key paths
34
+ encrypted_image_path = 0 # Tcurrent_dir("encrypted_image", user_id, filter)
35
+ evaluation_key_path = current_dir / ".fhe_keys/{user_id}"
36
+
37
+ # Write the files using the above paths
38
+ with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open(
39
+ "wb"
40
+ ) as evaluation_key:
41
+ encrypted_image.write(files[0].file.read())
42
+ evaluation_key.write(files[1].file.read())
43
+
44
+
45
+ @app.post("/run_fhe")
46
+ def run_fhe(
47
+ user_id: str = Form(),
48
+ filter: str = Form(),
49
+ ):
50
+ """Execute the filter on the encrypted input image using FHE."""
51
+ # Retrieve the encrypted input image and the evaluation key paths
52
+ encrypted_image_path = get_server_file_path("encrypted_image", user_id, filter)
53
+ evaluation_key_path = get_server_file_path("evaluation_key", user_id, filter)
54
+
55
+ # Read the files using the above paths
56
+ with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
57
+ "rb"
58
+ ) as evaluation_key_file:
59
+ encrypted_image = encrypted_image_file.read()
60
+ evaluation_key = evaluation_key_file.read()
61
+
62
+ # Load the FHE server
63
+ fhe_server = FHEServer(FILTERS_PATH / f"{filter}/deployment")
64
+
65
+ # Run the FHE execution
66
+ start = time.time()
67
+ encrypted_output_image = fhe_server.run(encrypted_image, evaluation_key)
68
+ fhe_execution_time = round(time.time() - start, 2)
69
+
70
+ # Retrieve the encrypted output image path
71
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
72
+
73
+ # Write the file using the above path
74
+ with encrypted_output_path.open("wb") as encrypted_output:
75
+ encrypted_output.write(encrypted_output_image)
76
+
77
+ return JSONResponse(content=fhe_execution_time)
78
+
79
+
80
+ @app.post("/get_output")
81
+ def get_output(
82
+ user_id: str = Form(),
83
+ filter: str = Form(),
84
+ ):
85
+ """Retrieve the encrypted output image."""
86
+ # Retrieve the encrypted output image path
87
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id, filter)
88
+
89
+ # Read the file using the above path
90
+ with encrypted_output_path.open("rb") as encrypted_output_file:
91
+ encrypted_output = encrypted_output_file.read()
92
+
93
+ return Response(encrypted_output)
symptoms_categories.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In this file, we roughly split up a list of symptoms, taken from "./training.csv" file, avalaible
3
+ through: "https://github.com/anujdutt9/Disease-Prediction-from-Symptoms/tree/master/dataset"
4
+ into medical categories, in order to make the UI more plesant for the users.
5
+
6
+ Each variable contains a list of symptoms sthat can be pecific to a part of the body or to a list
7
+ of similar symptoms.
8
+ """
9
+
10
+ import itertools
11
+
12
+ import pandas as pd
13
+
14
+ DIGESTIVE_SYSTEM_SYPTOMS = {
15
+ "Digestive system syptoms": [
16
+ "stomach_pain",
17
+ "acidity",
18
+ "vomiting",
19
+ "indigestion",
20
+ "constipation",
21
+ "abdominal_pain",
22
+ "diarrhoea",
23
+ "belly_pain",
24
+ "nausea",
25
+ "distention_of_abdomen",
26
+ "stomach_bleeding",
27
+ "pain_during_bowel_movements",
28
+ "passage_of_gases",
29
+ "brittle_nails",
30
+ "red_spots_over_body",
31
+ "swelling_of_stomach",
32
+ "bloody_stool",
33
+ "yellowish_skin",
34
+ "irritation_in_anus",
35
+ "pain_in_anal_region",
36
+ "abnormal_menstruation",
37
+ ]
38
+ }
39
+
40
+ SKIN_SYPTOMS = {
41
+ "Skin related symptoms": [
42
+ "itching",
43
+ "skin_rash",
44
+ "pus_filled_pimples",
45
+ "blackheads",
46
+ "scurving",
47
+ "skin_peeling",
48
+ "silver_like_dusting",
49
+ "small_dents_in_nails",
50
+ "inflammatory_nails",
51
+ "blister",
52
+ "red_sore_around_nose",
53
+ "bruising",
54
+ "yellow_crust_ooze",
55
+ "dischromic_patches",
56
+ "nodal_skin_eruptions",
57
+ ]
58
+ }
59
+
60
+ ORL_SYPTOMS = {
61
+ "ORL_SYPTOMS": [
62
+ "loss_of_smell",
63
+ "continuous_sneezing",
64
+ "runny_nose",
65
+ "patches_in_throat",
66
+ "throat_irritation",
67
+ "sinus_pressure",
68
+ "enlarged_thyroid",
69
+ "loss_of_balance",
70
+ "unsteadiness",
71
+ "dizziness",
72
+ "spinning_movements",
73
+ ]
74
+ }
75
+
76
+ THORAX_SYMPTOMS = {
77
+ "THORAX_RELATED_SYMPTOMS": [
78
+ "breathlessness",
79
+ "chest_pain",
80
+ "cough",
81
+ "rusty_sputum",
82
+ "phlegm",
83
+ "mucoid_sputum",
84
+ "congestion",
85
+ "blood_in_sputum",
86
+ "fast_heart_rate",
87
+ ]
88
+ }
89
+
90
+ EYES_SYMPTOMS = {
91
+ "Eyes_related_symptoms": [
92
+ "sunken_eyes",
93
+ "redness_of_eyes",
94
+ "watering_from_eyes",
95
+ "blurred_and_distorted_vision",
96
+ "pain_behind_the_eyes",
97
+ "visual_disturbances",
98
+ ]
99
+ }
100
+
101
+ VASCULAR_LYMPHATIC_SYMPTOMS = {
102
+ "VASCULAR_LYMPHATIC_SYMPTOMS": [
103
+ "cold_hands_and_feets",
104
+ "swollen_blood_vessels",
105
+ "swollen_legs",
106
+ "swelled_lymph_nodes",
107
+ "palpitations",
108
+ "prominent_veins_on_calf",
109
+ "yellowing_of_eyes",
110
+ "puffy_face_and_eyes",
111
+ "fluid_overload",
112
+ "fluid_overload.1",
113
+ "swollen_extremeties",
114
+ ]
115
+ }
116
+
117
+ UROLOGICAL_SYMPTOMS = {
118
+ "UROLOGICAL_SYMPTOMS": [
119
+ "burning_micturition",
120
+ "spotting_urination",
121
+ "yellow_urine",
122
+ "bladder_discomfort",
123
+ "foul_smell_of_urine",
124
+ "continuous_feel_of_urine",
125
+ "polyuria",
126
+ "dark_urine",
127
+ ]
128
+ }
129
+
130
+ MUSCULOSKELETAL_SYMPTOMS = {
131
+ "MUSCULOSKELETAL_SYMPTOMS": [
132
+ "joint_pain",
133
+ "muscle_wasting",
134
+ "muscle_pain",
135
+ "muscle_weakness",
136
+ "knee_pain",
137
+ "stiff_neck",
138
+ "swelling_joints",
139
+ "movement_stiffness",
140
+ "hip_joint_pain",
141
+ "painful_walking",
142
+ "weakness_of_one_body_side",
143
+ "neck_pain",
144
+ "back_pain",
145
+ "weakness_in_limbs",
146
+ "cramps",
147
+ ]
148
+ }
149
+
150
+ FEELING_SYMPTOMS = {
151
+ "FEELING_SYPTOMS": [
152
+ "anxiety",
153
+ "restlessness",
154
+ "lethargy",
155
+ "mood_swings",
156
+ "depression",
157
+ "irritability",
158
+ "lack_of_concentration",
159
+ "fatigue",
160
+ "malaise",
161
+ "weight_gain",
162
+ "increased_appetite",
163
+ "weight_loss",
164
+ "loss_of_appetite",
165
+ "obesity",
166
+ "excessive_hunger",
167
+ ]
168
+ }
169
+
170
+ OTHER_SYPTOMS = {
171
+ "OTHER_SYPTOMS": [
172
+ "ulcers_on_tongue",
173
+ "shivering",
174
+ "chills",
175
+ "irregular_sugar_level",
176
+ "high_fever",
177
+ "slurred_speech",
178
+ "sweating",
179
+ "internal_itching",
180
+ "mild_fever",
181
+ "toxic_look_(typhos)",
182
+ "acute_liver_failure",
183
+ "dehydration",
184
+ "headache",
185
+ "extra_marital_contacts",
186
+ "drying_and_tingling_lips",
187
+ "altered_sensorium",
188
+ ]
189
+ }
190
+
191
+ PATIENT_HISTORY = {
192
+ "PATIENT_HISTORY": [
193
+ "family_history",
194
+ "receiving_blood_transfusion",
195
+ "receiving_unsterile_injections",
196
+ "history_of_alcohol_consumption",
197
+ "coma",
198
+ ]
199
+ }
200
+
201
+ SYMPTOMS_LIST = [
202
+ SKIN_SYPTOMS,
203
+ EYES_SYMPTOMS,
204
+ ORL_SYPTOMS,
205
+ THORAX_SYMPTOMS,
206
+ DIGESTIVE_SYSTEM_SYPTOMS,
207
+ UROLOGICAL_SYMPTOMS,
208
+ VASCULAR_LYMPHATIC_SYMPTOMS,
209
+ MUSCULOSKELETAL_SYMPTOMS,
210
+ FEELING_SYMPTOMS,
211
+ PATIENT_HISTORY,
212
+ OTHER_SYPTOMS,
213
+ ]
214
+
215
+
216
+ def test(file_path="./Training.csv"):
217
+ df = pd.read_csv(file_path, index_col=0)
218
+ valid_column = df.columns
219
+ all_symptoms = [category.values() for category in SYMPTOMS_LIST]
220
+ all_symptoms = list(itertools.chain.from_iterable(all_symptoms))
221
+ all_symptoms = list(itertools.chain.from_iterable(all_symptoms))
222
+ set(valid_column) - set(all_symptoms), set(all_symptoms) - set(valid_column)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ test()