File size: 5,234 Bytes
0908a41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""Custom client-server interface implementation for filter models."""
from pathlib import Path

from concrete import fhe

from filters import Filter
class FHEServer:
    """Server interface to run an FHE circuit over encrypted inputs."""

    def __init__(self, path_dir):
        """Initialize the FHE server interface.

        Args:
            path_dir (Union[str, Path]): The path to the directory where the circuit is
                saved. A plain string is accepted and converted to a ``Path``.
        """
        # Normalize to a Path so the "/" join below also works for str inputs
        self.path_dir = Path(path_dir)

        # Load the FHE circuit (the "server.zip" artifact written by FHEDev.save)
        self.server = fhe.Server.load(self.path_dir / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run the filter on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized, still-encrypted output of the filter.
        """
        # Deserialize the encrypted input image and the evaluation keys
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the filter in FHE
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize the encrypted output so it can be sent back to the client
        serialized_encrypted_output = encrypted_output.serialize()

        return serialized_encrypted_output
class FHEDev:
    """Development interface to export a compiled filter's client and server artifacts."""

    def __init__(self, filter, path_dir):
        """Initialize the FHE development interface.

        Args:
            filter (Filter): The filter to use in the FHE interface. It must expose an
                ``fhe_circuit`` attribute, which has to be set (compiled) before ``save``.
            path_dir (Union[str, Path]): The path to the directory where the circuit is
                saved. It is created, including parents, if it does not exist.
        """
        # NOTE: ``filter`` shadows the builtin; the name is kept so keyword callers
        # keep working.
        self.filter = filter

        # Accept both str and Path: ``mkdir`` and the "/" joins in ``save`` need a Path,
        # so a plain string used to fail here with AttributeError.
        self.path_dir = Path(path_dir)
        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes ``server.zip`` and ``client.zip`` into ``path_dir``.

        Raises:
            AssertionError: If the filter has not been compiled (``fhe_circuit`` is None).
        """
        assert self.filter.fhe_circuit is not None, (
            "The model must be compiled before saving it."
        )

        # Save the circuit for the server, using via_mlir in order to handle
        # cross-platform execution
        path_circuit_server = self.path_dir / "server.zip"
        self.filter.fhe_circuit.server.save(path_circuit_server, via_mlir=True)

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        self.filter.fhe_circuit.client.save(path_circuit_client)
class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a Filter."""

    def __init__(self, path_dir, filter_name, key_dir=None):
        """Initialize the FHE client interface.

        Args:
            path_dir (Union[str, Path]): The path to the directory where the circuit is
                saved. A plain string is accepted and converted to a ``Path``.
            filter_name (str): The filter's name to consider.
            key_dir (Path): The path to the directory where the keys are stored. Defaults
                to None; the value is passed as-is to ``fhe.Client.load``.

        Raises:
            AssertionError: If ``path_dir`` does not exist.
        """
        # Accept both str and Path: the existence check and the "/" join need a Path,
        # so a plain string used to fail with AttributeError instead of the check below.
        self.path_dir = Path(path_dir)
        self.key_dir = key_dir

        # If path_dir does not exist raise
        assert self.path_dir.exists(), (
            f"{path_dir} does not exist. Please specify a valid path."
        )

        # Load the client (the "client.zip" artifact written by FHEDev.save)
        self.client = fhe.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Instantiate the filter
        self.filter = Filter(filter_name)

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        # Pass by keyword so the call does not depend on keygen's positional order
        self.client.keygen(force=force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the clear input image.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        serialized_encrypted_image = encrypted_image.serialize()

        return serialized_encrypted_image

    def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
        """Deserialize, decrypt and post-process the output image in the clear.

        Args:
            serialized_encrypted_output_image (bytes): The serialized and encrypted output image.

        Returns:
            numpy.ndarray: The deserialized, decrypted and post-processed image.
        """
        # Deserialize the encrypted image
        encrypted_output_image = fhe.Value.deserialize(
            serialized_encrypted_output_image
        )

        # Decrypt the image
        output_image = self.client.decrypt(encrypted_output_image)

        # Post-process the image
        post_processed_output_image = self.filter.post_processing(output_image)

        return post_processed_output_image
|