binoua's picture
chore: adding a system to save key piece by piece
df7187e
raw
history blame contribute delete
No virus
2.2 kB
from typing import Any, Dict, List, Optional

import numpy as np
from concrete.ml.deployment import FHEModelServer
def from_json(python_object):
    """Decode the JSON form produced by ``to_json`` back into ``bytes``.

    A dict carrying a ``"__class__"`` marker is converted by passing its
    ``"__value__"`` entry (an iterable of integers in 0..255) to ``bytes``.
    Anything else is returned unchanged, so the function can also serve as a
    ``json.loads(..., object_hook=from_json)`` hook without turning ordinary
    JSON objects into ``None``.

    Raises:
        KeyError: the marker is present but ``"__value__"`` is missing.
        TypeError: ``"__value__"`` is an integer (or otherwise not convertible
            to bytes).
        ValueError: an element of ``"__value__"`` is outside 0..255.
    """
    # Only dicts can carry the marker; the isinstance check also avoids a
    # substring match when a plain string happens to contain "__class__".
    if not (isinstance(python_object, dict) and "__class__" in python_object):
        return python_object

    value = python_object["__value__"]
    if isinstance(value, int):
        # bytes(n) does not decode anything: it allocates n zero bytes, which
        # lets a crafted request exhaust memory. Reject it explicitly.
        raise TypeError("'__value__' must be a sequence of byte values, not an integer")
    return bytes(value)
def to_json(python_object):
    """Encode a bytes-like object as a JSON-serializable dict.

    The result has the shape ``{"__class__": "bytes", "__value__": [ints]}``,
    which ``from_json`` turns back into ``bytes``. The function follows the
    contract of the ``default=`` hook of ``json.dumps``: it either returns a
    serializable replacement or raises ``TypeError``.

    ``bytearray`` is accepted in addition to ``bytes``; it is tagged as
    ``"bytes"`` too, so it decodes to an immutable ``bytes`` object.

    Raises:
        TypeError: ``python_object`` is neither ``bytes`` nor ``bytearray``.
    """
    if isinstance(python_object, (bytes, bytearray)):
        # Iterating a bytes-like object yields its byte values as ints.
        return {"__class__": "bytes", "__value__": list(python_object)}
    raise TypeError(repr(python_object) + " is not JSON serializable")
class EndpointHandler:
    """Request handler that stores evaluation keys and runs FHE inference.

    Keys are kept in an in-memory dict keyed by an integer uid, so they are
    lost when the process restarts and are never evicted.
    """

    def __init__(self, path=""):
        # For server
        self.fhemodel_server = FHEModelServer(path + "/compiled_model")

        # Simulate a database of keys
        self.key_database = {}

    def __call__(self, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Dispatch one request according to ``data["method"]``.

        Fields are popped from ``data``, so the dict is modified in place.

        Supported methods and the fields each one reads:
            ``"save_key"``: ``evaluation_keys`` (serialized bytes). Stores the
                first piece of a key and returns ``{"uid": <int>}``.
            ``"append_key"``: ``uid``, ``evaluation_keys``. Appends a further
                piece to the key stored under ``uid``; returns ``None``.
            ``"inference"``: ``uid``, ``encrypted_inputs`` (serialized bytes).
                Returns the encrypted prediction encoded by ``to_json``.

        Any other (or missing) method returns ``None``.

        Raises:
            ValueError: a required bytes field is missing or malformed.
            KeyError: ``append_key`` was called with an unknown ``uid``.
            AssertionError: ``inference`` was called with an unknown ``uid``.
        """
        # Default to None rather than to the request dict itself.
        method = data.pop("method", None)

        if method == "save_key":
            return self._save_key(data)
        if method == "append_key":
            self._append_key(data)
            return None
        if method == "inference":
            return self._inference(data)
        return None

    @staticmethod
    def _pop_bytes(data, field):
        """Pop ``field`` from ``data`` and decode it to ``bytes`` via ``from_json``."""
        if field not in data:
            raise ValueError(f"missing required field {field!r}")
        decoded = from_json(data.pop(field))
        # Guard against payloads that do not decode to bytes, so that nothing
        # but real key material or ciphertext reaches the database / the model.
        if not isinstance(decoded, bytes):
            raise ValueError(f"field {field!r} is not a serialized bytes object")
        return decoded

    def _new_uid(self):
        """Return a random integer in [0, 2**32) not yet used as a key id."""
        while True:
            # An explicit 64-bit dtype keeps the 2**32 bound valid where the
            # default NumPy integer is 32-bit; int() yields a plain Python int
            # so the uid is serializable by any JSON encoder.
            uid = int(np.random.randint(0, 2**32, dtype=np.uint64))
            if uid not in self.key_database:
                return uid

    def _save_key(self, data):
        """Store the first piece of an evaluation key under a fresh uid."""
        evaluation_keys = self._pop_bytes(data, "evaluation_keys")
        uid = self._new_uid()
        self.key_database[uid] = evaluation_keys
        return {"uid": uid}

    def _append_key(self, data):
        """Append another piece to the evaluation key stored under ``uid``."""
        uid = data.pop("uid", None)
        # Check the uid before decoding the payload: fail fast and cheaply.
        if uid not in self.key_database:
            raise KeyError(f"{uid} not in DB")
        self.key_database[uid] += self._pop_bytes(data, "evaluation_keys")

    def _inference(self, data):
        """Run the model on ``encrypted_inputs`` with the key stored under ``uid``."""
        uid = data.pop("uid", None)
        if uid not in self.key_database:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``. The message deliberately omits the stored uids:
            # listing them would hand every caller the ids of other clients.
            raise AssertionError(f"{uid} not in DB")

        # Get inputs
        encrypted_inputs = self._pop_bytes(data, "encrypted_inputs")

        # Find key in the database
        evaluation_keys = self.key_database[uid]

        # Run CML prediction
        encrypted_prediction = self.fhemodel_server.run(encrypted_inputs, evaluation_keys)
        return to_json(encrypted_prediction)