VaultChem commited on
Commit
dd8b714
1 Parent(s): aaffad9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -7
app.py CHANGED
@@ -1,14 +1,198 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  def update(name):
5
  return f"Welcome to Gradio, {name}!"
6
 
7
 
8
- with gr.Blocks() as app:
9
- gr.Markdown("### Start typing below and then click **Run** to see the output.")
10
- with gr.Column():
11
- gr.File(interactive=False)
12
- gr.UploadButton(label="Upload an executable file")
13
-
14
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import time
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import subprocess
7
+ from concrete.ml.deployment import FHEModelClient
8
+ from requests import head
9
+ import numpy
10
+ import os
11
+ from pathlib import Path
12
+ import requests
13
+ import json
14
+ import base64
15
+ import subprocess
16
+ import shutil
17
+ import time
18
+ import pandas as pd
19
+ import pickle
20
+ import numpy as np
21
+
22
+ # This repository's directory
23
+ REPO_DIR = Path(__file__).parent
24
+ subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
25
+
26
+ # if not exists, create a directory for the FHE keys called .fhe_keys
27
+ if not os.path.exists(".fhe_keys"):
28
+ os.mkdir(".fhe_keys")
29
+ # if not exists, create a directory for the tmp files called tmp
30
+ if not os.path.exists("tmp"):
31
+ os.mkdir("tmp")
32
+
33
+
34
+ # Wait 4 sec for the server to start
35
+ time.sleep(4)
36
+
37
+
38
+ # Encrypted data limit for the browser to display
39
+ # (encrypted data is too large to display in the browser)
40
+ ENCRYPTED_DATA_BROWSER_LIMIT = 500
41
+ N_USER_KEY_STORED = 20
42
+
43
+
44
+ def clean_tmp_directory():
45
+ # Allow 20 user keys to be stored.
46
+ # Once that limitation is reached, deleted the oldest.
47
+ path_sub_directories = sorted(
48
+ [f for f in Path(".fhe_keys/").iterdir() if f.is_dir()], key=os.path.getmtime
49
+ )
50
+
51
+ user_ids = []
52
+ if len(path_sub_directories) > N_USER_KEY_STORED:
53
+ n_files_to_delete = len(path_sub_directories) - N_USER_KEY_STORED
54
+ for p in path_sub_directories[:n_files_to_delete]:
55
+ user_ids.append(p.name)
56
+ shutil.rmtree(p)
57
+
58
+ list_files_tmp = Path("tmp/").iterdir()
59
+ # Delete all files related to user_id
60
+ for file in list_files_tmp:
61
+ for user_id in user_ids:
62
+ if file.name.endswith(f"{user_id}.npy"):
63
+ file.unlink()
64
+
65
+
66
+ def keygen():
67
+ # Clean tmp directory if needed
68
+ clean_tmp_directory()
69
+
70
+ print("Initializing FHEModelClient...")
71
+ # Let's create a user_id
72
+ user_id = numpy.random.randint(0, 2**32)
73
+ fhe_api = FHEModelClient(f"deployment/deployment_{task}", f".fhe_keys/{user_id}")
74
+ fhe_api.load()
75
+
76
+ # Generate a fresh key
77
+ fhe_api.generate_private_and_evaluation_keys(force=True)
78
+ evaluation_key = fhe_api.get_serialized_evaluation_keys()
79
+
80
+ numpy.save(f"tmp/tmp_evaluation_key_{user_id}.npy", evaluation_key)
81
+
82
+ return [list(evaluation_key)[:ENCRYPTED_DATA_BROWSER_LIMIT], user_id]
83
+
84
+
85
+ def encode_quantize_encrypt(test_file, user_id):
86
+
87
+ fhe_api = FHEModelClient(f"fhe_model", f".fhe_keys/{user_id}")
88
+ fhe_api.load()
89
+ from PE_main import extract_infos
90
+
91
+ features = pickle.loads(open(os.path.join("features.pkl"), "rb").read())
92
+ encodings = extract_infos(test_file)
93
+ encodings = list(map(lambda x: encodings[x], features))
94
+
95
+ quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8)
96
+ encrypted_quantized_encoding = fhe_api.quantize_encrypt_serialize(encodings)
97
+
98
+ # Save encrypted_quantized_encoding in a file, since too large to pass through regular Gradio
99
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
100
+ numpy.save(
101
+ f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy",
102
+ encrypted_quantized_encoding,
103
+ )
104
+
105
+ # Compute size
106
+ encrypted_quantized_encoding_shorten = list(encrypted_quantized_encoding)[
107
+ :ENCRYPTED_DATA_BROWSER_LIMIT
108
+ ]
109
+ encrypted_quantized_encoding_shorten_hex = "".join(
110
+ f"{i:02x}" for i in encrypted_quantized_encoding_shorten
111
+ )
112
+ return (
113
+ encodings[0],
114
+ quantized_encodings[0],
115
+ encrypted_quantized_encoding_shorten_hex,
116
+ )
117
+
118
+
119
+ def run_fhe(user_id):
120
+ encoded_data_path = Path(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy")
121
+ encrypted_quantized_encoding = numpy.load(encoded_data_path)
122
+
123
+ # Read evaluation_key from the file
124
+ evaluation_key = numpy.load(f"tmp/tmp_evaluation_key_{user_id}.npy")
125
+
126
+ # Use base64 to encode the encodings and evaluation key
127
+ encrypted_quantized_encoding = base64.b64encode(
128
+ encrypted_quantized_encoding
129
+ ).decode()
130
+ encoded_evaluation_key = base64.b64encode(evaluation_key).decode()
131
+
132
+ query = {}
133
+ query["evaluation_key"] = encoded_evaluation_key
134
+ query["encrypted_encoding"] = encrypted_quantized_encoding
135
+ headers = {"Content-type": "application/json"}
136
+
137
+
138
+ response = requests.post(
139
+ "http://localhost:8000/predict",
140
+ data=json.dumps(query),
141
+ headers=headers,
142
+ )
143
+
144
+
145
+ encrypted_prediction = base64.b64decode(response.json()["encrypted_prediction"])
146
+
147
+ numpy.save(f"tmp/tmp_encrypted_prediction_{user_id}.npy", encrypted_prediction)
148
+ encrypted_prediction_shorten = list(encrypted_prediction)[
149
+ :ENCRYPTED_DATA_BROWSER_LIMIT
150
+ ]
151
+ encrypted_prediction_shorten_hex = "".join(
152
+ f"{i:02x}" for i in encrypted_prediction_shorten
153
+ )
154
+
155
+
156
+ def decrypt_prediction(user_id):
157
+ encoded_data_path = Path(f"tmp/tmp_encrypted_prediction_{user_id}.npy")
158
+
159
+ # Read encrypted_prediction from the file
160
+
161
+ encrypted_prediction = numpy.load(encoded_data_path).tobytes()
162
+
163
+ fhe_api = FHEModelClient(f"fhe_model", f".fhe_keys/{user_id}")
164
+ fhe_api.load()
165
+
166
+ # We need to retrieve the private key that matches the client specs (see issue #18)
167
+ fhe_api.generate_private_and_evaluation_keys(force=False)
168
+
169
+ predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_prediction)
170
 
171
 
172
  def update(name):
173
  return f"Welcome to Gradio, {name}!"
174
 
175
 
176
+ if __name__ == "__main__":
177
+ app = gr.Interface(
178
+ [
179
+ keygen,
180
+ encode_quantize_encrypt,
181
+ run_fhe,
182
+ decrypt_prediction,
183
+ ],
184
+ [
185
+ gr.inputs.Textbox(label="Task", default="malware"),
186
+ gr.inputs.File(label="Test File"),
187
+ gr.inputs.Textbox(label="User ID"),
188
+ ],
189
+ [
190
+ gr.outputs.Textbox(label="Evaluation Key"),
191
+ gr.outputs.Textbox(label="Encodings"),
192
+ gr.outputs.Textbox(label="Encrypted Quantized Encoding"),
193
+ gr.outputs.Textbox(label="Encrypted Prediction"),
194
+ ],
195
+ title="FHE Model",
196
+ description="This is a FHE Model",
197
+ )
198
+ app.launch()