File size: 5,433 Bytes
127130c fbd9a75 127130c 4040d43 127130c fbd9a75 127130c 4040d43 127130c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"Client-server interface custom implementation for seizure detection models."
from pathlib import Path

from concrete import fhe

from common import SEIZURE_DETECTION_MODEL_PATH
from seizure_detection import SeizureDetector
class FHEServer:
    """Server interface to run a FHE circuit for seizure detection."""

    def __init__(self, model_path):
        """Initialize the FHE interface.

        Args:
            model_path (Union[str, Path]): The path to the directory where the circuit is
                saved. Strings are converted to ``Path`` so both forms are accepted.
        """
        # Coerce to Path: the artifact location below is built with the `/` operator,
        # which a plain str does not support.
        self.model_path = Path(model_path)

        # Load the server-side FHE circuit from the "server.zip" artifact
        self.server = fhe.Server.load(self.model_path / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run seizure detection on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized encrypted circuit output. The client is expected to
                decrypt and post-process it into the seizure-detection result.
        """
        # Rebuild the FHE objects from the bytes sent by the client
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the seizure detection directly on encrypted data
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize so the result can be sent back to the client
        return encrypted_output.serialize()
class FHEDev:
    """Development interface to export a compiled seizure detection model's FHE artifacts."""

    def __init__(self, seizure_detector, model_path):
        """Initialize the FHE interface.

        Args:
            seizure_detector (SeizureDetector): The seizure detection model to use in the FHE
                interface. Its ``fhe_circuit`` attribute must be set (model compiled) before
                ``save`` is called.
            model_path (Union[str, Path]): The path to the directory where the circuit is saved.
                The directory (and any missing parents) is created if needed.
        """
        self.seizure_detector = seizure_detector

        # Accept both str and Path: the path is later used with `.mkdir()` and the `/`
        # operator, which only Path supports.
        self.model_path = Path(model_path)
        self.model_path.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes ``server.zip`` and ``client.zip`` into ``model_path``.

        Raises:
            AssertionError: If the model has not been compiled (``fhe_circuit`` is None),
                when assertions are enabled.
        """
        circuit = self.seizure_detector.fhe_circuit
        assert circuit is not None, (
            "The model must be compiled before saving it."
        )

        # Save the circuit for the server, using via_mlir in order to handle cross-platform
        # execution
        circuit.server.save(self.model_path / "server.zip", via_mlir=True)

        # Save the circuit for the client
        circuit.client.save(self.model_path / "client.zip")
class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a SeizureDetector."""

    def __init__(self, key_dir=None, model_path=None):
        """Initialize the FHE interface.

        Args:
            key_dir (Optional[Union[str, Path]]): The path to the directory where the keys are
                stored. Defaults to None.
            model_path (Optional[Union[str, Path]]): The path to the directory where the circuit
                is saved. Defaults to None, meaning ``SEIZURE_DETECTION_MODEL_PATH``.

        Raises:
            AssertionError: If ``model_path`` does not exist (when assertions are enabled).
        """
        # model_path is a trailing optional parameter so existing positional calls
        # FHEClient(key_dir) keep working and the historical default is preserved.
        if model_path is None:
            model_path = SEIZURE_DETECTION_MODEL_PATH
        # Coerce to Path so `.exists()` and the `/` operator work for str input too
        self.model_path = Path(model_path)
        self.key_dir = key_dir

        # If model_path does not exist raise
        assert self.model_path.exists(), f"{self.model_path} does not exist. Please specify a valid path."

        # Load the client from the saved "client.zip" artifact; key_dir is forwarded as the
        # key storage directory
        self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)

        # Instantiate the seizure detector (used for post-processing decrypted outputs)
        self.seizure_detector = SeizureDetector()

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force=force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The evaluation keys, serialized for sending to the server.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the input image in the clear.

        No pre-processing is applied here; the image is encrypted as given.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output):
        """Deserialize, decrypt and post-process the output in the clear.

        Args:
            serialized_encrypted_output (bytes): The serialized and encrypted output.

        Returns:
            The result of ``SeizureDetector.post_processing`` on the decrypted output
            (presumably a bool indicating seizure detection — confirm against SeizureDetector).
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)

        # Decrypt the output
        output = self.client.decrypt(encrypted_output)

        # Post-process the output (if needed)
        return self.seizure_detector.post_processing(output)
|