Roman committed on
Commit 3cf0931 • 1 Parent(s): 79ec538

chore: Add comments, clean unused objects and improve ridge detection

app.py CHANGED
@@ -19,7 +19,7 @@ from common import (
19
  REPO_DIR,
20
  SERVER_URL,
21
  )
22
- from custom_client_server import CustomFHEClient, CustomFHEServer
23
 
24
  # Uncomment here to have both the server and client in the same terminal
25
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
@@ -27,11 +27,16 @@ time.sleep(3)
27
 
28
 
29
  def decrypt_output_with_wrong_key(encrypted_image, image_filter):
 
 
 
30
  filter_path = FILTERS_PATH / f"{image_filter}/deployment"
31
 
 
32
  wrong_client = CustomFHEClient(filter_path, WRONG_KEYS_PATH)
33
  wrong_client.generate_private_and_evaluation_keys(force=True)
34
 
 
35
  output_image = wrong_client.deserialize_decrypt_post_process(encrypted_image)
36
 
37
  return output_image
 
19
  REPO_DIR,
20
  SERVER_URL,
21
  )
22
+ from custom_client_server import CustomFHEClient
23
 
24
  # Uncomment here to have both the server and client in the same terminal
25
  subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
 
27
 
28
 
29
  def decrypt_output_with_wrong_key(encrypted_image, image_filter):
30
+ """Decrypt the encrypted output using a different private key.
31
+ """
32
+ # Retrieve the filter's deployment path
33
  filter_path = FILTERS_PATH / f"{image_filter}/deployment"
34
 
35
+ # Instantiate the client interface and generate a new private key
36
  wrong_client = CustomFHEClient(filter_path, WRONG_KEYS_PATH)
37
  wrong_client.generate_private_and_evaluation_keys(force=True)
38
 
39
+ # Deserialize, decrypt and post-process the encrypted output using the new private key
40
  output_image = wrong_client.deserialize_decrypt_post_process(encrypted_image)
41
 
42
  return output_image
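
For context, a minimal sketch of the two key directories involved here, run from the repository root. The helper above then feeds the server's encrypted output to the client holding the freshly generated (wrong) private key, so the decrypted result comes out as noise rather than the filtered image; the "sharpen" filter below is just an example choice.

    from common import FILTERS_PATH, KEYS_PATH, WRONG_KEYS_PATH
    from custom_client_server import CustomFHEClient

    filter_path = FILTERS_PATH / "sharpen/deployment"

    # Client holding the keys actually used to encrypt the input image
    client = CustomFHEClient(filter_path, KEYS_PATH)
    client.generate_private_and_evaluation_keys(force=True)

    # Second client with its own key directory: its private key does not match the
    # ciphertexts produced above, which is what decrypt_output_with_wrong_key relies on
    wrong_client = CustomFHEClient(filter_path, WRONG_KEYS_PATH)
    wrong_client.generate_private_and_evaluation_keys(force=True)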
common.py CHANGED
@@ -2,26 +2,23 @@
2
 
3
  from pathlib import Path
4
 
5
- import numpy as np
6
- from PIL import Image
7
-
8
- # The repository's directory
9
  REPO_DIR = Path(__file__).parent
10
 
11
- # The repository's main directories
12
  FILTERS_PATH = REPO_DIR / "filters"
13
  KEYS_PATH = REPO_DIR / ".fhe_keys"
14
  WRONG_KEYS_PATH = REPO_DIR / ".wrong_keys"
15
  CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
16
  SERVER_TMP_PATH = REPO_DIR / "server_tmp"
17
 
18
- # Create the directories if it does not exist yet
19
  KEYS_PATH.mkdir(exist_ok=True)
20
  WRONG_KEYS_PATH.mkdir(exist_ok=True)
21
  CLIENT_TMP_PATH.mkdir(exist_ok=True)
22
  SERVER_TMP_PATH.mkdir(exist_ok=True)
23
 
24
- # All the filters currently available in the app
25
  AVAILABLE_FILTERS = [
26
  "identity",
27
  "inverted",
@@ -32,25 +29,14 @@ AVAILABLE_FILTERS = [
32
  "ridge detection",
33
  ]
34
 
35
- # The input image's shape. Images with larger input shapes will be cropped and/or resized to this
36
  INPUT_SHAPE = (100, 100)
37
 
38
- # Generate random images as an inputset for compilation
39
- np.random.seed(42)
40
- INPUTSET = tuple(
41
- np.random.randint(0, 255, size=(INPUT_SHAPE + (3,)), dtype=np.int64) for _ in range(10)
42
- )
43
-
44
-
45
- def load_image(image_path):
46
- image = Image.open(image_path).convert("RGB").resize(INPUT_SHAPE)
47
- image = np.asarray(image, dtype="int64")
48
- return image
49
-
50
-
51
- _INPUTSET_DIR = REPO_DIR / "input_examples"
52
 
53
- # List of all image examples suggested in the app
54
- EXAMPLES = [str(image) for image in _INPUTSET_DIR.glob("**/*")]
55
 
 
56
  SERVER_URL = "http://localhost:8000/"
 
2
 
3
  from pathlib import Path
4
 
5
+ # This repository's directory
 
 
 
6
  REPO_DIR = Path(__file__).parent
7
 
8
+ # This repository's main folders
9
  FILTERS_PATH = REPO_DIR / "filters"
10
  KEYS_PATH = REPO_DIR / ".fhe_keys"
11
  WRONG_KEYS_PATH = REPO_DIR / ".wrong_keys"
12
  CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
13
  SERVER_TMP_PATH = REPO_DIR / "server_tmp"
14
 
15
+ # Create the necessary folders
16
  KEYS_PATH.mkdir(exist_ok=True)
17
  WRONG_KEYS_PATH.mkdir(exist_ok=True)
18
  CLIENT_TMP_PATH.mkdir(exist_ok=True)
19
  SERVER_TMP_PATH.mkdir(exist_ok=True)
20
 
21
+ # All the filters currently available in the demo
22
  AVAILABLE_FILTERS = [
23
  "identity",
24
  "inverted",
 
29
  "ridge detection",
30
  ]
31
 
32
+ # The input images' shape. Images with different input shapes will be cropped and resized by Gradio
33
  INPUT_SHAPE = (100, 100)
34
 
35
+ # Retrieve the input examples directory
36
+ INPUT_EXAMPLES_DIR = REPO_DIR / "input_examples"
 
37
 
38
+ # List of all image examples suggested in the demo
39
+ EXAMPLES = [str(image) for image in INPUT_EXAMPLES_DIR.glob("**/*")]
40
 
41
+ # Store the server's URL
42
  SERVER_URL = "http://localhost:8000/"
compile.py CHANGED
@@ -2,10 +2,8 @@
2
 
3
  import json
4
  import shutil
5
-
6
- import numpy as np
7
  import onnx
8
- from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET, KEYS_PATH
9
  from custom_client_server import CustomFHEClient, CustomFHEDev
10
 
11
  print("Starting compiling the filters.")
@@ -13,18 +11,17 @@ print("Starting compiling the filters.")
13
  for image_filter in AVAILABLE_FILTERS:
14
  print("\nCompiling filter:", image_filter)
15
 
16
- # Load the onnx model
17
- onnx_model = onnx.load(FILTERS_PATH / f"{image_filter}/server.onnx")
18
-
19
  deployment_path = FILTERS_PATH / f"{image_filter}/deployment"
20
 
21
- # Retrieve the client API related to the current filter
22
  model = CustomFHEClient(deployment_path, KEYS_PATH).model
23
 
24
- image_shape = INPUT_SHAPE + (3,)
 
25
 
26
- # Compile the model using the loaded onnx model
27
- model.compile(INPUTSET, onnx_model=onnx_model)
28
 
29
  processing_json_path = deployment_path / "serialized_processing.json"
30
 
@@ -36,11 +33,11 @@ for image_filter in AVAILABLE_FILTERS:
36
  if deployment_path.is_dir():
37
  shutil.rmtree(deployment_path)
38
 
39
- # Save the files needed for deployment
40
- fhe_api = CustomFHEDev(model=model, path_dir=deployment_path)
41
- fhe_api.save()
42
 
43
- # Write the serialized_processing.json file to the deployment folder
44
  with open(processing_json_path, "w") as f:
45
  json.dump(serialized_processing, f)
46
 
 
2
 
3
  import json
4
  import shutil
 
 
5
  import onnx
6
+ from common import AVAILABLE_FILTERS, FILTERS_PATH, KEYS_PATH
7
  from custom_client_server import CustomFHEClient, CustomFHEDev
8
 
9
  print("Starting compiling the filters.")
 
11
  for image_filter in AVAILABLE_FILTERS:
12
  print("\nCompiling filter:", image_filter)
13
 
14
+ # Retrieve the deployment files associated with the current filter
 
 
15
  deployment_path = FILTERS_PATH / f"{image_filter}/deployment"
16
 
17
+ # Retrieve the client associated with the current filter
18
  model = CustomFHEClient(deployment_path, KEYS_PATH).model
19
 
20
+ # Load the onnx model
21
+ onnx_model = onnx.load(FILTERS_PATH / f"{image_filter}/server.onnx")
22
 
23
+ # Compile the model on a representative inputset, using the loaded onnx model
24
+ model.compile(onnx_model=onnx_model)
25
 
26
  processing_json_path = deployment_path / "serialized_processing.json"
27
 
 
33
  if deployment_path.is_dir():
34
  shutil.rmtree(deployment_path)
35
 
36
+ # Save the development files needed for deployment
37
+ fhe_dev = CustomFHEDev(model=model, path_dir=deployment_path)
38
+ fhe_dev.save()
39
 
40
+ # Write the serialized_processing.json file in the deployment directory
41
  with open(processing_json_path, "w") as f:
42
  json.dump(serialized_processing, f)
43
 
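After compile.py runs, each filter directory is expected to end up with the artifacts referenced elsewhere in this commit (client.zip, server.zip and serialized_processing.json). A small sketch to list them, using "ridge detection" as an example:

    from common import FILTERS_PATH

    deployment_path = FILTERS_PATH / "ridge detection/deployment"

    # List the deployment artifacts produced for this filter
    for artifact in sorted(deployment_path.iterdir()):
        print(artifact.name)  # client.zip, serialized_processing.json, server.zip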
custom_client_server.py CHANGED
@@ -1,4 +1,4 @@
1
- "Client-server interface implementation for custom models."
2
 
3
  from pathlib import Path
4
  from typing import Any
@@ -11,16 +11,16 @@ from concrete.ml.common.debugging.custom_assert import assert_true
11
 
12
 
13
  class CustomFHEDev:
14
- """Dev API to save the custom model and then load and run the FHE circuit."""
15
 
16
  model: Any = None
17
 
18
  def __init__(self, path_dir: str, model: Any = None):
19
- """Initialize the FHE API.
20
 
21
  Args:
22
- path_dir (str): the path to the directory where the circuit is saved
23
- model (Any): the model to use for the FHE API
24
  """
25
 
26
  self.path_dir = Path(path_dir)
@@ -33,7 +33,7 @@ class CustomFHEDev:
33
  """Export all needed artifacts for the client and server.
34
 
35
  Raises:
36
- Exception: path_dir is not empty
37
  """
38
  # Check if the path_dir is empty with pathlib
39
  listdir = list(Path(self.path_dir).glob("**/*"))
@@ -73,11 +73,11 @@ class CustomFHEClient:
73
  client: cnp.Client
74
 
75
  def __init__(self, path_dir: str, key_dir: str = None):
76
- """Initialize the FHE API.
77
 
78
  Args:
79
- path_dir (str): the path to the directory where the circuit is saved
80
- key_dir (str): the path to the directory where the keys are stored
81
  """
82
  self.path_dir = Path(path_dir)
83
  self.key_dir = Path(key_dir)
@@ -103,7 +103,7 @@ class CustomFHEClient:
103
  """Generate the private and evaluation keys.
104
 
105
  Args:
106
- force (bool): if True, regenerate the keys even if they already exist
107
  """
108
  self.client.keygen(force)
109
 
@@ -111,7 +111,7 @@ class CustomFHEClient:
111
  """Get the serialized evaluation keys.
112
 
113
  Returns:
114
- cnp.EvaluationKeys: the evaluation keys
115
  """
116
  return self.client.evaluation_keys.serialize()
117
 
@@ -119,10 +119,10 @@ class CustomFHEClient:
119
  """Encrypt and serialize the values.
120
 
121
  Args:
122
- x (numpy.ndarray): the values to encrypt and serialize
123
 
124
  Returns:
125
- cnp.PublicArguments: the encrypted and serialized values
126
  """
127
  # Pre-process the values
128
  x = self.model.pre_processing(x)
@@ -140,10 +140,10 @@ class CustomFHEClient:
140
  """Deserialize, decrypt and post-process the values.
141
 
142
  Args:
143
- serialized_encrypted_output (cnp.PublicArguments): the serialized and encrypted output
144
 
145
  Returns:
146
- numpy.ndarray: the decrypted values
147
  """
148
  # Deserialize the encrypted values
149
  deserialized_encrypted_output = self.client.specs.unserialize_public_result(
@@ -159,15 +159,15 @@ class CustomFHEClient:
159
 
160
 
161
  class CustomFHEServer:
162
- """Server API to load and run the FHE circuit."""
163
 
164
  server: cnp.Server
165
 
166
  def __init__(self, path_dir: str):
167
- """Initialize the FHE API.
168
 
169
  Args:
170
- path_dir (str): the path to the directory where the circuit is saved
171
  """
172
 
173
  self.path_dir = Path(path_dir)
@@ -187,11 +187,11 @@ class CustomFHEServer:
187
  """Run the model on the server over encrypted data.
188
 
189
  Args:
190
- serialized_encrypted_data (cnp.PublicArguments): the encrypted and serialized data
191
- serialized_evaluation_keys (cnp.EvaluationKeys): the serialized evaluation keys
192
 
193
  Returns:
194
- cnp.PublicResult: the result of the model
195
  """
196
  assert_true(self.server is not None, "Model has not been loaded.")
197
 
 
1
+ "Client-server interface implementation for custom integer models."
2
 
3
  from pathlib import Path
4
  from typing import Any
 
11
 
12
 
13
  class CustomFHEDev:
14
+ """Dev API to save the custom integer model, load and run a FHE circuit."""
15
 
16
  model: Any = None
17
 
18
  def __init__(self, path_dir: str, model: Any = None):
19
+ """Initialize the development interface.
20
 
21
  Args:
22
+ path_dir (str): The path to the directory where the circuit is saved.
23
+ model (Any): The model to use for the development interface.
24
  """
25
 
26
  self.path_dir = Path(path_dir)
 
33
  """Export all needed artifacts for the client and server.
34
 
35
  Raises:
36
+ Exception: path_dir is not empty.
37
  """
38
  # Check if the path_dir is empty with pathlib
39
  listdir = list(Path(self.path_dir).glob("**/*"))
 
73
  client: cnp.Client
74
 
75
  def __init__(self, path_dir: str, key_dir: str = None):
76
+ """Initialize the client interface.
77
 
78
  Args:
79
+ path_dir (str): The path to the directory where the circuit is saved.
80
+ key_dir (str): The path to the directory where the keys are stored.
81
  """
82
  self.path_dir = Path(path_dir)
83
  self.key_dir = Path(key_dir)
 
103
  """Generate the private and evaluation keys.
104
 
105
  Args:
106
+ force (bool): If True, regenerate the keys even if they already exist.
107
  """
108
  self.client.keygen(force)
109
 
 
111
  """Get the serialized evaluation keys.
112
 
113
  Returns:
114
+ cnp.EvaluationKeys: The evaluation keys.
115
  """
116
  return self.client.evaluation_keys.serialize()
117
 
 
119
  """Encrypt and serialize the values.
120
 
121
  Args:
122
+ x (numpy.ndarray): The values to encrypt and serialize.
123
 
124
  Returns:
125
+ cnp.PublicArguments: The encrypted and serialized values.
126
  """
127
  # Pre-process the values
128
  x = self.model.pre_processing(x)
 
140
  """Deserialize, decrypt and post-process the values.
141
 
142
  Args:
143
+ serialized_encrypted_output (cnp.PublicArguments): The serialized and encrypted output.
144
 
145
  Returns:
146
+ numpy.ndarray: The decrypted values.
147
  """
148
  # Deserialize the encrypted values
149
  deserialized_encrypted_output = self.client.specs.unserialize_public_result(
 
159
 
160
 
161
  class CustomFHEServer:
162
+ """Server interface to load and run a FHE circuit."""
163
 
164
  server: cnp.Server
165
 
166
  def __init__(self, path_dir: str):
167
+ """Initialize the server interface.
168
 
169
  Args:
170
+ path_dir (str): The path to the directory where the circuit is saved.
171
  """
172
 
173
  self.path_dir = Path(path_dir)
 
187
  """Run the model on the server over encrypted data.
188
 
189
  Args:
190
+ serialized_encrypted_data (cnp.PublicArguments): The encrypted and serialized data.
191
+ serialized_evaluation_keys (cnp.EvaluationKeys): The serialized evaluation keys.
192
 
193
  Returns:
194
+ cnp.PublicResult: The result of the model.
195
  """
196
  assert_true(self.server is not None, "Model has not been loaded.")
197
 
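To place these three classes, here is a sketch of the intended round trip, run from the repository root. The encryption and evaluation-key methods are not visible in this diff, so the names pre_process_encrypt_serialize and get_serialized_evaluation_keys below are assumptions based on the docstrings; the remaining calls follow app.py and server.py.

    import numpy as np

    from common import FILTERS_PATH, KEYS_PATH
    from custom_client_server import CustomFHEClient, CustomFHEServer

    deployment_path = FILTERS_PATH / "blur/deployment"

    # Client side: generate keys, then pre-process, encrypt and serialize an image
    client = CustomFHEClient(deployment_path, KEYS_PATH)
    client.generate_private_and_evaluation_keys(force=True)
    image = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.int64)
    encrypted_image = client.pre_process_encrypt_serialize(image)  # method name assumed
    evaluation_keys = client.get_serialized_evaluation_keys()  # method name assumed

    # Server side: execute the FHE circuit on the encrypted data
    server = CustomFHEServer(deployment_path)
    encrypted_output = server.run(encrypted_image, evaluation_keys)

    # Client side again: deserialize, decrypt and post-process the result
    output_image = client.deserialize_decrypt_post_process(encrypted_output)
    print(output_image.shape)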
filters.py CHANGED
@@ -4,7 +4,7 @@ import json
4
 
5
  import numpy as np
6
  import torch
7
- from common import AVAILABLE_FILTERS
8
  from concrete.numpy.compilation.compiler import Compiler
9
  from torch import nn
10
 
@@ -63,17 +63,18 @@ class _TorchRotate(nn.Module):
63
  class _TorchConv2D(nn.Module):
64
  """Torch model for applying a single 2D convolution operator on images."""
65
 
66
- def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1):
67
- """Initializing the filter
68
 
69
  Args:
70
  kernel (np.ndarray): The convolution kernel to consider.
71
  """
72
  super().__init__()
73
- self.kernel = kernel
74
  self.n_out_channels = n_out_channels
75
  self.n_in_channels = n_in_channels
76
  self.groups = groups
 
77
 
78
  def forward(self, x):
79
  """Forward pass for filtering the image using a 2D kernel.
@@ -113,7 +114,14 @@ class _TorchConv2D(nn.Module):
113
  f"{kernel_shape}"
114
  )
115
 
116
- return nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)
 
117
 
118
 
119
  class Filter:
@@ -135,6 +143,8 @@ class Filter:
135
  )
136
 
137
  self.filter = image_filter
 
 
138
  self.divide = None
139
  self.repeat_out_channels = False
140
 
@@ -156,65 +166,62 @@ class Filter:
156
  # However, since FHE computations require weights to be integers, we first multiply
157
  # these by a factor of 1000. The output image's values are then divided by 1000 in
158
  # post-processing in order to retrieve the correct result
159
- kernel = torch.tensor([299, 587, 114])
160
 
161
  self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
162
 
163
- # Division value for post-processing
164
  self.divide = 1000
165
 
166
- # Grayscaled image needs to be put in RGB format for Gradio display
 
167
  self.repeat_out_channels = True
168
 
169
  elif image_filter == "blur":
170
- kernel = torch.ones((3, 3), dtype=torch.int64)
171
 
172
  self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
173
 
174
- # Division value for post-processing
175
  self.divide = 9
176
 
177
  elif image_filter == "sharpen":
178
- kernel = torch.tensor(
179
- [
180
- [0, -1, 0],
181
- [-1, 5, -1],
182
- [0, -1, 0],
183
- ]
184
- )
185
 
186
  self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
187
 
188
  elif image_filter == "ridge detection":
189
- kernel = torch.tensor(
190
- [
191
- [-1, -1, -1],
192
- [-1, 9, -1],
193
- [-1, -1, -1],
194
- ]
195
- )
196
-
197
- self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
198
-
199
- # Ridge detection is usually displayed as a grayscaled image, which needs to be put in
200
- # RGB format for Gradio display
 
201
  self.repeat_out_channels = True
202
 
203
- self.onnx_model = None
204
- self.fhe_circuit = None
205
-
206
- def compile(self, inputset, onnx_model=None):
207
- """Compile the model using an inputset.
208
 
209
  Args:
210
- inputset (List[np.ndarray]): The set of images to use for compilation
211
  onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
212
  generated automatically using a NumpyModule. Default to None.
213
  """
214
- # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
215
- # the same shape conventions.
 
216
  inputset = tuple(
217
- np.expand_dims(input.transpose(2, 0, 1), axis=0).astype(np.int64) for input in inputset
218
  )
219
 
220
  # If no onnx model was given, generate a new one.
@@ -243,30 +250,31 @@ class Filter:
243
  return self.fhe_circuit
244
 
245
  def pre_processing(self, input_image):
246
- """Processing that needs to be applied before encryption.
247
 
248
  Args:
249
- input_image (np.ndarray): The image to pre-process
250
 
251
  Returns:
252
- input_image (np.ndarray): The pre-processed image
253
  """
254
  # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
255
  # the same shape conventions.
 
256
  input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64)
257
 
258
  return input_image
259
 
260
  def post_processing(self, output_image):
261
- """Processing that needs to be applied after decryption.
262
 
263
  Args:
264
- input_image (np.ndarray): The decrypted image to post-process
265
 
266
  Returns:
267
- input_image (np.ndarray): The post-processed image
268
  """
269
- # Apply a division if needed
270
  if self.divide is not None:
271
  output_image //= self.divide
272
 
@@ -277,7 +285,7 @@ class Filter:
277
  # the same shape conventions.
278
  output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)
279
 
280
- # Grayscaled image needs to be put in RGB format for Gradio display
281
  if self.repeat_out_channels:
282
  output_image = output_image.repeat(3, axis=2)
283
 
 
4
 
5
  import numpy as np
6
  import torch
7
+ from common import AVAILABLE_FILTERS, INPUT_SHAPE
8
  from concrete.numpy.compilation.compiler import Compiler
9
  from torch import nn
10
 
 
63
  class _TorchConv2D(nn.Module):
64
  """Torch model for applying a single 2D convolution operator on images."""
65
 
66
+ def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
67
+ """Initialize the filter.
68
 
69
  Args:
70
  kernel (np.ndarray): The convolution kernel to consider.
71
  """
72
  super().__init__()
73
+ self.kernel = torch.tensor(kernel, dtype=torch.int64)
74
  self.n_out_channels = n_out_channels
75
  self.n_in_channels = n_in_channels
76
  self.groups = groups
77
+ self.threshold = threshold
78
 
79
  def forward(self, x):
80
  """Forward pass for filtering the image using a 2D kernel.
 
114
  f"{kernel_shape}"
115
  )
116
 
117
+ # Apply the convolution
118
+ x = nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)
119
+
120
+ # Subtract the threshold value if one is given
121
+ if self.threshold is not None:
122
+ x -= self.threshold
123
+
124
+ return x
125
 
126
 
127
  class Filter:
 
143
  )
144
 
145
  self.filter = image_filter
146
+ self.onnx_model = None
147
+ self.fhe_circuit = None
148
  self.divide = None
149
  self.repeat_out_channels = False
150
 
 
166
  # However, since FHE computations require weights to be integers, we first multiply
167
  # these by a factor of 1000. The output image's values are then divided by 1000 in
168
  # post-processing in order to retrieve the correct result
169
+ kernel = [299, 587, 114]
170
 
171
  self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
172
 
173
+ # Define the value used for dividing the output values in post-processing
174
  self.divide = 1000
175
 
176
+ # Indicate that the out_channels will need to be repeated, as Gradio requires all
177
+ # images to have an RGB format, even for grayscale ones
178
  self.repeat_out_channels = True
179
 
180
  elif image_filter == "blur":
181
+ kernel = np.ones((3, 3))
182
 
183
  self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
184
 
185
+ # Define the value used for dividing the output values in post-processing
186
  self.divide = 9
187
 
188
  elif image_filter == "sharpen":
189
+ kernel = [
190
+ [0, -1, 0],
191
+ [-1, 5, -1],
192
+ [0, -1, 0],
193
+ ]
 
 
194
 
195
  self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
196
 
197
  elif image_filter == "ridge detection":
198
+ kernel = [
199
+ [-1, -1, -1],
200
+ [-1, 9, -1],
201
+ [-1, -1, -1],
202
+ ]
203
+
204
+ # In addition to the convolution operator, the filter will subtract a given threshold
205
+ # value from the result in order to better display the ridges
206
+ self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1, threshold=900)
207
+
208
+ # Indicate that the out_channels will need to be repeated, as Gradio requires all
209
+ # images to have an RGB format, even for grayscale ones. Ridge detection images are
210
+ # usually displayed as such
211
  self.repeat_out_channels = True
212
 
213
+ def compile(self, onnx_model=None):
214
+ """Compile the model on a representative inputset.
 
 
 
215
 
216
  Args:
 
217
  onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
218
  generated automatically using a NumpyModule. Default to None.
219
  """
220
+ # Generate a random representative set of images used for compilation, following Torch's
221
+ # shape format (batch, in_channels, image_height, image_width)
222
+ np.random.seed(42)
223
  inputset = tuple(
224
+ np.random.randint(0, 255, size=((1, 3) + INPUT_SHAPE), dtype=np.int64) for _ in range(10)
225
  )
226
 
227
  # If no onnx model was given, generate a new one.
 
250
  return self.fhe_circuit
251
 
252
  def pre_processing(self, input_image):
253
+ """Apply pre-processing to the encrypted input images.
254
 
255
  Args:
256
+ input_image (np.ndarray): The image to pre-process.
257
 
258
  Returns:
259
+ input_image (np.ndarray): The pre-processed image.
260
  """
261
  # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
262
  # the same shape conventions.
263
+ # Additionally, make sure the input images are made of integers only
264
  input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64)
265
 
266
  return input_image
267
 
268
  def post_processing(self, output_image):
269
+ """Apply post-processing to the encrypted output images.
270
 
271
  Args:
272
+ output_image (np.ndarray): The decrypted image to post-process.
273
 
274
  Returns:
275
+ output_image (np.ndarray): The post-processed image.
276
  """
277
+ # Divide all values if needed
278
  if self.divide is not None:
279
  output_image //= self.divide
280
 
 
285
  # the same shape conventions.
286
  output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)
287
 
288
+ # Gradio requires all images to follow an RGB format
289
  if self.repeat_out_channels:
290
  output_image = output_image.repeat(3, axis=2)
291
 
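The ridge detection change is easiest to see in plain NumPy: on a flat region the kernel's coefficients sum to 1, so the output is roughly the pixel value itself and subtracting the 900 threshold pushes it below zero, while genuine ridges keep large positive values. A self-contained sketch (the clipping at the end is only for display, not part of the filter's post-processing):

    import numpy as np

    # Ridge detection kernel and threshold used in the updated filter
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.int64)
    threshold = 900

    rng = np.random.default_rng(42)
    image = rng.integers(0, 255, size=(8, 8), dtype=np.int64)

    # Valid 2D cross-correlation (what conv2d computes), written out explicitly
    out = np.zeros((6, 6), dtype=np.int64)
    for i in range(6):
        for j in range(6):
            out[i, j] = np.sum(image[i : i + 3, j : j + 3] * kernel)

    # Flat regions stay near the original pixel value, so subtracting 900 drives
    # them negative; strong ridges remain well above zero
    print((out - threshold).clip(0, 255))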
filters/black and white/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a8fae28225c53cc6e184535a8880187626892461a7f0c25afe322dfaa83f678
3
  size 388
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eef269dcec0d548972cc25c3ef9abd8067bd8df8e4a30b53a1b3006575b70baf
3
  size 388
filters/black and white/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4253afb0e46be27e5aa2562c14e7bc402fc7aa39cf01a824f1170c5e46ccf9aa
3
  size 4364
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ca528ad3f3b99b6c69b0bb0e0a4724615a6be7ac2222424b7c2ac48c26e5b95
3
  size 4364
filters/blur/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ead03bc109e6598c90000f580c2d149f19c4ed968dfa302dfea78c131bdad02
3
  size 391
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce25848e14481bf54e4a52fad3ea178bc78ebf2d62e464839da4de58c5a48d43
3
  size 391
filters/blur/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:deb92cf90d25f2c4552bb8a2fa0eb0f0cbe0fdc549665f2b63a15a78c0b20d72
3
  size 7263
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a058aeab0894ea93e00db344a9e71abeb63c6e8faa8bdb661ae4b304d3eee5c
3
  size 7263
filters/identity/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:05f39bf162b0ffcd9e4105c371180a621ebdc64fce12a751547424e3a6ec0c0b
3
  size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b285786e91816d4f1848968d6737929a90f073d2aabac607b0fe5cd0867f314a
3
  size 378
filters/identity/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9741844b0e166078608310f3812724cf0dd117927457763ecec48e5e510ebc5a
3
  size 2559
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:438384a8517e5ccb354851b9a8baa3ee86af59726d9f1600d98527f0568059b5
3
  size 2559
filters/inverted/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:05f39bf162b0ffcd9e4105c371180a621ebdc64fce12a751547424e3a6ec0c0b
3
  size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b285786e91816d4f1848968d6737929a90f073d2aabac607b0fe5cd0867f314a
3
  size 378
filters/inverted/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e3b413c12bb4d656e739247af1ffb848d57a70a2f57011dd367af66c56cce025
3
  size 4179
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca6086a06f95b349609433162e316ceddf05dfe6ea2b0936492123ff46f417a7
3
  size 4179
filters/ridge detection/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b58c440c56ec350092d730eb2e8d2118225eed501fb8dba5df17b96440fe6a08
3
  size 397
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8fe63b4e3b322a2c4dd1bb742878b2e90c1b6c151dc2af7bb16155fea29a66c
3
  size 397
filters/ridge detection/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:91c225472c2d225d3211b54a1f4bca1a7df54ae6543e6d62e7c6a48601dd9e31
3
- size 4479
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58266bb522e40f8ba7746e1eca6191e7a1c3c385e99b294c759bbbc88f7e6408
3
+ size 5043
filters/ridge detection/server.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:48821745ed7a9b25b5ba8ae0dc3da35739985bf5dd1dac5b3a9c207adbbf1c45
3
- size 532
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42f9914e64003c33c7eceb639a001ceb4460c8226e0e380cb032741851e41c49
3
+ size 648
filters/rotate/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:05f39bf162b0ffcd9e4105c371180a621ebdc64fce12a751547424e3a6ec0c0b
3
  size 378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b285786e91816d4f1848968d6737929a90f073d2aabac607b0fe5cd0867f314a
3
  size 378
filters/rotate/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3f4c3cfa73f19125923edca845c70a55eb5a882ddb17fc8f979ed349b845eb6e
3
  size 4431
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecea959453ffd704efba1c5e22db54e902cc6c3289870ece101793d1479cb347
3
  size 4431
filters/sharpen/deployment/client.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a13c773025de645be26fbedd7bb9b0926464ad08f992ff089aee90ee58df5f3
3
  size 396
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4df798a79bfc380debbfbc7a9cdaf79a096fe1deb18327f31dc141bea38f8d4e
3
  size 396
filters/sharpen/deployment/server.zip CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:65892df870b0e07b717fb971b85c0e6ed2d0acc345db9fec1eb90497b1092374
3
  size 7311
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:befb2eaff02cc855af745dc82bc8f24ce713e4ff4393e3b635f55b8f82e0ff20
3
  size 7311
generate_dev_files.py CHANGED
@@ -1,11 +1,8 @@
1
  "A script to generate all development files necessary for the image filtering demo."
2
 
3
  import shutil
4
- from pathlib import Path
5
-
6
- import numpy as np
7
  import onnx
8
- from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET
9
  from custom_client_server import CustomFHEDev
10
  from filters import Filter
11
 
@@ -17,16 +14,16 @@ for image_filter in AVAILABLE_FILTERS:
17
  # Create the filter instance
18
  filter = Filter(image_filter)
19
 
20
- image_shape = INPUT_SHAPE + (3,)
21
-
22
- # Compile the filter on the inputset
23
- filter.compile(INPUTSET)
24
 
 
25
  filter_path = FILTERS_PATH / image_filter
26
 
 
27
  deployment_path = filter_path / "deployment"
28
 
29
- # Delete the deployment folder and its content if it exist
30
  if deployment_path.is_dir():
31
  shutil.rmtree(deployment_path)
32
 
 
1
  "A script to generate all development files necessary for the image filtering demo."
2
 
3
  import shutil
 
 
 
4
  import onnx
5
+ from common import AVAILABLE_FILTERS, FILTERS_PATH
6
  from custom_client_server import CustomFHEDev
7
  from filters import Filter
8
 
 
14
  # Create the filter instance
15
  filter = Filter(image_filter)
16
 
17
+ # Compile the model on a representative inputset
18
+ filter.compile()
 
 
19
 
20
+ # Define the directory path associated with this filter
21
  filter_path = FILTERS_PATH / image_filter
22
 
23
+ # Define the directory path associated with this filter's deployment files
24
  deployment_path = filter_path / "deployment"
25
 
26
+ # Delete the deployment folder and its content if it already exists
27
  if deployment_path.is_dir():
28
  shutil.rmtree(deployment_path)
29
 
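The loop above can be reproduced for a single filter, which is handy when iterating on one kernel; a short sketch using the ridge detection filter (run from the repository root):

    from filters import Filter

    # Build the torch model for one filter and compile it on the built-in random inputset
    ridge_filter = Filter("ridge detection")
    fhe_circuit = ridge_filter.compile()

    # The compiled circuit is also stored on the instance for later use
    print(fhe_circuit is ridge_filter.fhe_circuit)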
server.py CHANGED
@@ -43,9 +43,12 @@ def send_input(
43
  filter: str = Form(),
44
  files: List[UploadFile] = File(),
45
  ):
 
 
46
  encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id)
47
  evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id)
48
-
 
49
  with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open(
50
  "wb"
51
  ) as evaluation_key:
@@ -58,26 +61,30 @@ def run_fhe(
58
  user_id: str = Form(),
59
  filter: str = Form(),
60
  ):
61
-
 
62
  encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id)
63
  evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id)
64
 
 
65
  with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
66
  "rb"
67
  ) as evaluation_key_file:
68
  encrypted_image = encrypted_image_file.read()
69
  evaluation_key = evaluation_key_file.read()
70
 
71
- # Load the model
72
- fhe_model = CustomFHEServer(FILTERS_PATH / f"{filter}/deployment")
73
 
74
  # Run the FHE execution
75
  start = time.time()
76
- encrypted_output_image = fhe_model.run(encrypted_image, evaluation_key)
77
  fhe_execution_time = round(time.time() - start, 2)
78
 
 
79
  encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id)
80
 
 
81
  with encrypted_output_path.open("wb") as encrypted_output:
82
  encrypted_output.write(encrypted_output_image)
83
 
@@ -89,8 +96,11 @@ def get_output(
89
  user_id: str = Form(),
90
  filter: str = Form(),
91
  ):
 
 
92
  encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id)
93
 
 
94
  with encrypted_output_path.open("rb") as encrypted_output_file:
95
  encrypted_output = encrypted_output_file.read()
96
 
 
43
  filter: str = Form(),
44
  files: List[UploadFile] = File(),
45
  ):
46
+ """Send the inputs to the server."""
47
+ # Retrieve the encrypted input image and the evaluation key paths
48
  encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id)
49
  evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id)
50
+
51
+ # Write the files using the above paths
52
  with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open(
53
  "wb"
54
  ) as evaluation_key:
 
61
  user_id: str = Form(),
62
  filter: str = Form(),
63
  ):
64
+ """Execute the filter on the encrypted input image using FHE."""
65
+ # Retrieve the encrypted input image and the evaluation key paths
66
  encrypted_image_path = get_server_file_path("encrypted_image", filter, user_id)
67
  evaluation_key_path = get_server_file_path("evaluation_key", filter, user_id)
68
 
69
+ # Read the files using the above paths
70
  with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
71
  "rb"
72
  ) as evaluation_key_file:
73
  encrypted_image = encrypted_image_file.read()
74
  evaluation_key = evaluation_key_file.read()
75
 
76
+ # Load the FHE server
77
+ fhe_server = CustomFHEServer(FILTERS_PATH / f"{filter}/deployment")
78
 
79
  # Run the FHE execution
80
  start = time.time()
81
+ encrypted_output_image = fhe_server.run(encrypted_image, evaluation_key)
82
  fhe_execution_time = round(time.time() - start, 2)
83
 
84
+ # Retrieve the encrypted output image path
85
  encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id)
86
 
87
+ # Write the file using the above path
88
  with encrypted_output_path.open("wb") as encrypted_output:
89
  encrypted_output.write(encrypted_output_image)
90
 
 
96
  user_id: str = Form(),
97
  filter: str = Form(),
98
  ):
99
+ """Retrieve the encrypted output image."""
100
+ # Retrieve the encrypted output image path
101
  encrypted_output_path = get_server_file_path("encrypted_output", filter, user_id)
102
 
103
+ # Read the file using the above path
104
  with encrypted_output_path.open("rb") as encrypted_output_file:
105
  encrypted_output = encrypted_output_file.read()
106
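
Finally, a sketch of how a client might call these endpoints with requests. The routes themselves are not shown in this diff, so the paths below assume they mirror the handler names (send_input, run_fhe, get_output); the user_id and filter form fields match the handler signatures.

    import requests

    from common import SERVER_URL

    data = {"user_id": "demo-user", "filter": "ridge detection"}

    # Trigger the FHE execution for a previously uploaded encrypted image and key
    # (route assumed to mirror the run_fhe handler)
    response = requests.post(SERVER_URL + "run_fhe", data=data)
    print(response.status_code)

    # Fetch the resulting encrypted output image (route assumed to mirror get_output)
    response = requests.post(SERVER_URL + "get_output", data=data)
    encrypted_output = response.content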