File size: 4,265 Bytes
d831908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
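"""Server that listens for requests from the client.

Exposes a GET welcome route at "/" and one POST route per molecular property
("/predict_<PROPERTY>") that runs FHE inference on encrypted inputs.
"""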
import base64
from pathlib import Path

from concrete.ml.deployment import FHEModelServer
from fastapi import FastAPI, HTTPException
from joblib import load
from pydantic import BaseModel

current_dir = Path(__file__).parent
_deployment_root = current_dir / "deployment"

# Load the models: one FHEModelServer per predicted property, each built at
# import time from its own deployment/deployment_<i> directory.
fhe_model_HLM = FHEModelServer(_deployment_root / "deployment_0")
fhe_model_MDR1MDCK = FHEModelServer(_deployment_root / "deployment_1")
fhe_model_SOLUBILITY = FHEModelServer(_deployment_root / "deployment_2")
fhe_model_PROTEIN_BINDING_HUMAN = FHEModelServer(_deployment_root / "deployment_3")
fhe_model_PROTEIN_BINDING_RAT = FHEModelServer(_deployment_root / "deployment_4")
fhe_model_RLM_CLint = FHEModelServer(_deployment_root / "deployment_5")


class PredictRequest(BaseModel):
    # JSON request body shared by every /predict_* endpoint. Both fields are
    # base64-encoded strings; the handlers base64-decode them to bytes and
    # pass them to FHEModelServer.run(encrypted_encoding, evaluation_key).
    # NOTE(review): presumably produced by the matching FHE client
    # (serialized evaluation keys / encrypted input) — confirm.
    evaluation_key: str  # base64 evaluation key, decoded and passed to run()
    encrypted_encoding: str  # base64 encrypted input, decoded and passed to run()


# The FastAPI application; every route in this module is registered on it.
app = FastAPI()


# Default route ("/"): returns a fixed welcome message.
@app.get("/")
def root():
    welcome_text = "Welcome to Your Molecular Property prediction FHE Server!"
    return {"message": welcome_text}


def _run_encrypted_prediction(fhe_model: FHEModelServer, query: PredictRequest) -> dict:
    """Run encrypted inference for ``query`` on ``fhe_model``.

    Shared implementation behind every ``/predict_*`` endpoint.

    Args:
        fhe_model: The FHEModelServer loaded for the requested property.
        query: Request carrying the base64-encoded evaluation key and
            encrypted input.

    Returns:
        ``{"encrypted_prediction": <str>}`` where the value is the base64
        encoding of the bytes returned by ``fhe_model.run``.

    Raises:
        HTTPException: 400 if either request field is not valid base64.
    """
    try:
        encrypted_encoding = base64.b64decode(query.encrypted_encoding)
        evaluation_key = base64.b64decode(query.evaluation_key)
    except ValueError as err:
        # binascii.Error (e.g. bad padding) subclasses ValueError, and
        # non-ASCII str input raises ValueError directly; both are client
        # mistakes, so report 400 instead of letting them become a 500.
        raise HTTPException(
            status_code=400, detail="Request fields must be valid base64."
        ) from err

    prediction = fhe_model.run(encrypted_encoding, evaluation_key)

    # Base64-encode the binary prediction so it can travel in a JSON response.
    encoded_prediction = base64.b64encode(prediction).decode()
    return {"encrypted_prediction": encoded_prediction}


@app.post("/predict_HLM")
def predict_HLM(query: PredictRequest):
    """Run the HLM model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_HLM, query)


@app.post("/predict_MDR1MDCK")
def predict_MDR1MDCK(query: PredictRequest):
    """Run the MDR1MDCK model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_MDR1MDCK, query)


@app.post("/predict_SOLUBILITY")
def predict_SOLUBILITY(query: PredictRequest):
    """Run the SOLUBILITY model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_SOLUBILITY, query)


@app.post("/predict_PROTEIN_BINDING_HUMAN")
def predict_PROTEIN_BINDING_HUMAN(query: PredictRequest):
    """Run the PROTEIN_BINDING_HUMAN model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_PROTEIN_BINDING_HUMAN, query)


@app.post("/predict_PROTEIN_BINDING_RAT")
def predict_PROTEIN_BINDING_RAT(query: PredictRequest):
    """Run the PROTEIN_BINDING_RAT model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_PROTEIN_BINDING_RAT, query)


# NOTE: the original file also had an undecorated duplicate of this function
# that was immediately shadowed by the decorated one; only this copy remains.
@app.post("/predict_RLM_CLint")
def predict_RLM_CLint(query: PredictRequest):
    """Run the RLM_CLint model on an encrypted input."""
    return _run_encrypted_prediction(fhe_model_RLM_CLint, query)