80cols commited on
Commit
949aa01
·
verified ·
1 Parent(s): 17f70ea

Create predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +135 -0
predictor.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import joblib
3
+ import numpy as np
4
+ from concrete.ml.deployment import FHEModelClient, FHEModelServer
5
+ import logging
6
+ import gradio as gr
7
+
8
# Configure logging
logging.basicConfig(level=logging.INFO)

# Module-level state shared across the Gradio callbacks below.
key_already_generated_condition = False  # True once evaluation keys exist
encrypted_data = None        # serialized ciphertext of the latest user input
encrypted_prediction = None  # serialized ciphertext returned by the server

# Paths to required files
SCALER_PATH = os.path.join("models", "scaler.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")

# Load the scaler (fitted at training time; required to preprocess inputs).
try:
    scaler = joblib.load(SCALER_PATH)
    logging.info("Scaler loaded successfully.")
except FileNotFoundError:
    # Fail fast: nothing downstream works without the scaler.
    logging.error(f"Error: The file scaler.pkl is missing at {SCALER_PATH}.")
    raise

# Initialize the FHE client and server from the compiled model artifacts.
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
    logging.info("FHE Client and Server initialized successfully.")
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise

# Serialized evaluation keys the server needs to run on encrypted data.
evaluation_keys = client.get_serialized_evaluation_keys()
38
+
39
def predict():
    """
    Run the compiled FHE model on the previously encrypted input.

    Reads the module-level ``encrypted_data`` produced by
    ``pre_process_encrypt_send_purchase`` and stores the resulting
    ciphertext in ``encrypted_prediction`` for later decryption.

    Returns:
        tuple: (hex string of the encrypted prediction or None,
                Gradio update carrying a status message).
    """
    global encrypted_data, encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # Execute the model on the encrypted payload; the server never
        # sees the plaintext inputs.
        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        logging.info(f"Encrypted Prediction: {encrypted_prediction}")
        return encrypted_prediction.hex(), gr.update(value="FHE evaluation is done. ✅")

    except Exception as e:
        logging.error(f"Error during prediction: {e}")
        # Bug fix: the original returned the "No encrypted data" message
        # here even though data WAS present — surface the real failure.
        return None, gr.update(value="Error during FHE evaluation ❌")
60
+
61
def decrypt_prediction():
    """
    Decrypt the latest encrypted prediction and render it for the UI.

    Returns:
        tuple: (verdict text, status message / Gradio update, HTML bar).
        Every path returns exactly three values so the Gradio output
        binding never breaks.
    """
    global encrypted_prediction
    if encrypted_prediction is None:
        # Bug fix: the original returned only TWO values on this branch
        # while the success and error branches return three, which breaks
        # the three Gradio outputs bound to this callback.
        msg = "No prediction to decrypt. Please make a prediction first. ❌"
        return msg, msg, msg
    try:
        # Decrypt and de-quantize locally using the client's private keys.
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")

        # Ensure the scores are a flat array before indexing classes.
        if isinstance(decrypted_prediction, np.ndarray) and decrypted_prediction.ndim > 1:
            decrypted_prediction = decrypted_prediction.flatten()

        # Highest-scoring class: index 0 = non-fraud, index 1 = fraud.
        binary_prediction = int(np.argmax(decrypted_prediction))

        # HTML percentage bar: green share = non-fraud, red share = fraud.
        bar_html = f"""
        <div style="width: 100%; background-color: lightgray; border-radius: 5px; overflow: hidden; display: flex;">
            <div style="width: {decrypted_prediction[0] * 100}%; background-color: green; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[0] * 100:.1f}% Non-Fraud
            </div>
            <div style="width: {decrypted_prediction[1] * 100}%; background-color: red; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[1] * 100:.1f}% Fraud
            </div>
        </div>
        """
        return "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊", gr.update(value="Decryption successful ✅"), bar_html

    except Exception as e:
        # Fixed log wording: this handler covers decryption, not prediction.
        logging.error(f"Error during decryption: {e}")
        return "Error during prediction❌", "Error during prediction❌", "Error during prediction❌"
97
+
98
def key_already_generated():
    """
    Report whether serialized evaluation keys are available.

    Side effect: sets the module-level flag
    ``key_already_generated_condition`` to True when keys exist.

    Returns:
        bool: True if the evaluation keys have been generated,
        False otherwise.
    """
    global key_already_generated_condition
    keys_present = bool(evaluation_keys)
    if keys_present:
        key_already_generated_condition = True
    return keys_present
109
+
110
def pre_process_encrypt_send_purchase(*inputs):
    """
    Scale, encrypt, and stage the purchase data for FHE prediction.

    Args:
        *inputs: Raw feature values in the order expected by the scaler.

    Returns:
        tuple: (hex representation of the encrypted input or None,
                Gradio update carrying a status message).
    """
    global key_already_generated_condition, encrypted_data
    # Idiomatic guard instead of `== False`.
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key before. ❌")
    try:
        key_already_generated_condition = True
        logging.info(f"Input Data: {inputs}")

        # Scale the input with the scaler fitted at training time.
        scaled_data = scaler.transform([list(inputs)])
        logging.info(f"Scaled Data: {scaled_data}")

        # Quantize + encrypt + serialize in a single client call.
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")
        return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
    except Exception as e:
        logging.error(f"Error during pre-processing: {e}")
        # Bug fix: the original returned a single string here while every
        # other path returns (value, status) — that breaks the two Gradio
        # outputs bound to this callback.
        return None, gr.update(value="Error during pre-processing ❌")
+ return "Error during pre-processing"