Riad-Quadratic
commited on
Commit
•
1d1124e
1
Parent(s):
d75dabd
REFACTOR: refactor code + add doc
Browse files- app.py +43 -163
- network.py +27 -1
- utils.py +193 -20
app.py
CHANGED
@@ -1,241 +1,121 @@
|
|
1 |
-
from concrete import fhe
|
2 |
-
import folium
|
3 |
-
import geopandas
|
4 |
-
from shapely.geometry import Point
|
5 |
import streamlit as st
|
6 |
-
from
|
7 |
-
|
8 |
import config
|
|
|
|
|
|
|
9 |
from network import get_frames
|
10 |
-
from utils import generate_path, set_up_server, set_up_client, display_encrypted
|
11 |
-
|
12 |
|
13 |
-
st.set_page_config(layout="wide")
|
14 |
ways = geopandas.read_file(config.roads_filepath)
|
15 |
nodes, _, rel = get_frames(ways)
|
16 |
|
17 |
server = set_up_server()
|
18 |
client = set_up_client(server.client_specs.serialize())
|
19 |
|
20 |
-
|
21 |
-
color="red",
|
22 |
-
marker_kwds=dict(radius=5, fill=True, name='node_id'),
|
23 |
-
tooltip="node_id",
|
24 |
-
tooltip_kwds=dict(labels=False),
|
25 |
-
zoom_control=False,
|
26 |
-
)
|
27 |
-
|
28 |
-
|
29 |
-
c1, c2, c3 = st.columns([1, 3, 1])
|
30 |
|
31 |
with c1:
|
32 |
-
|
33 |
|
34 |
with c3:
|
35 |
-
|
36 |
-
|
37 |
|
|
|
38 |
if 'evaluation_key' not in st.session_state:
|
39 |
with c1:
|
40 |
if st.button('Generate keys'):
|
41 |
with st.spinner('Generating keys'):
|
42 |
client.keys.load_if_exists_generate_and_save_otherwise(config.keys_filepath)
|
43 |
-
# client.keys.save(config.keys_filepath)
|
44 |
st.session_state['evaluation_key'] = client.evaluation_keys.serialize()
|
|
|
|
|
|
|
45 |
st.rerun()
|
46 |
else:
|
47 |
-
|
48 |
-
st.write("Encryption/decryption keys and evaluation keys are generated.")
|
49 |
-
st.write("The evaluation key is sent to the server.")
|
50 |
-
with c3:
|
51 |
-
st.write(f"Evaluation key: {display_encrypted(st.session_state['evaluation_key'])}")
|
52 |
-
|
53 |
-
|
54 |
if 'origin' not in st.session_state :
|
55 |
with c1:
|
56 |
st.write("Select the origin on the map")
|
57 |
-
|
58 |
with c2:
|
59 |
-
st_data_origin =
|
60 |
-
|
61 |
-
with c3:
|
62 |
-
st.write("")
|
63 |
|
64 |
if st_data_origin["last_object_clicked"]:
|
65 |
origin = Point(st_data_origin["last_object_clicked"]["lng"], st_data_origin["last_object_clicked"]["lat"])
|
|
|
66 |
origin_node = nodes[nodes['geometry'] == origin]['node_id'].values[0]
|
67 |
st.session_state['origin'] = origin
|
68 |
st.session_state['origin_node'] = origin_node
|
69 |
-
|
70 |
st.rerun()
|
71 |
|
|
|
72 |
if 'origin' in st.session_state and 'destination' not in st.session_state:
|
73 |
-
origin = st.session_state['origin']
|
74 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
75 |
origin_node = st.session_state['origin_node']
|
76 |
with c1:
|
77 |
-
st.write(f"Selected origin is node number: {origin_node}")
|
78 |
st.write("Select the destination on the map")
|
79 |
with c2:
|
80 |
-
st_data_destination =
|
81 |
-
with c3:
|
82 |
-
st.write("")
|
83 |
-
st.write("")
|
84 |
|
85 |
if st_data_destination["last_object_clicked"]:
|
86 |
destination = Point(st_data_destination["last_object_clicked"]["lng"], st_data_destination["last_object_clicked"]["lat"])
|
|
|
|
|
87 |
st.session_state['destination'] = destination
|
88 |
-
st.session_state['destination_node'] =
|
89 |
-
|
90 |
st.rerun()
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
95 |
-
destination = st.session_state['destination']
|
96 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
97 |
-
# breakpoint()
|
98 |
origin_node = st.session_state['origin_node']
|
99 |
destination_node = st.session_state['destination_node']
|
100 |
with c1:
|
101 |
-
st.write(f"Selected origin is node number: {origin_node}")
|
102 |
-
st.write(f"Selected destination is node number: {destination_node}")
|
103 |
if st.button('Encrypt and send inputs'):
|
104 |
with st.spinner('Encrypting inputs'):
|
105 |
client.keys.load(config.keys_filepath)
|
106 |
origin, destination = client.encrypt(origin_node, destination_node)
|
107 |
st.session_state['encrypted_origin'] = origin.serialize()
|
108 |
st.session_state['encrypted_destination'] = destination.serialize()
|
|
|
|
|
109 |
st.rerun()
|
110 |
with c2:
|
111 |
-
st_data_final =
|
112 |
-
with c3:
|
113 |
-
st.write("")
|
114 |
-
st.write("")
|
115 |
|
|
|
116 |
if 'encrypted_origin' in st.session_state and 'encrypted_shortest_path' not in st.session_state:
|
117 |
-
origin = st.session_state['origin']
|
118 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
119 |
-
destination = st.session_state['destination']
|
120 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
121 |
-
|
122 |
origin_node = st.session_state['origin_node']
|
123 |
destination_node = st.session_state['destination_node']
|
124 |
-
with c1:
|
125 |
-
st.write(f"Selected origin is node number: {origin_node}")
|
126 |
-
st.write(f"Selected destination is node number: {destination_node}")
|
127 |
with c2:
|
128 |
-
|
129 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
130 |
-
destination = st.session_state['destination']
|
131 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
132 |
-
st_data_final = st_folium(m, width=725, key="destination", returned_objects=[])
|
133 |
with c3:
|
134 |
-
st.
|
135 |
-
st.write("")
|
136 |
-
st.write(f"Received origin: {display_encrypted(st.session_state['encrypted_origin'])}")
|
137 |
-
st.write(f"Received destination: {display_encrypted(st.session_state['encrypted_destination'])}")
|
138 |
-
if st.button('Compute shortest path'):
|
139 |
with st.spinner('Computing'):
|
140 |
-
|
141 |
-
deserialized_destination = fhe.Value.deserialize(st.session_state['encrypted_destination'])
|
142 |
-
deserialized_evaluation_keys = fhe.EvaluationKeys.deserialize(st.session_state['evaluation_key'])
|
143 |
-
client.keys.load_if_exists_generate_and_save_otherwise(config.keys_filepath)
|
144 |
-
origin = st.session_state['origin_node']
|
145 |
-
destination = st.session_state['destination_node']
|
146 |
-
path = [origin, ]
|
147 |
-
encrypted_path = [st.session_state['encrypted_origin'], ]
|
148 |
-
o, d = deserialized_origin, deserialized_destination
|
149 |
-
for _ in range(nodes.shape[0] - 1):
|
150 |
-
# Careful: breaking early could lead to information leak
|
151 |
-
if origin == destination:
|
152 |
-
break
|
153 |
-
o = server.run(o, d, evaluation_keys=client.evaluation_keys)
|
154 |
-
origin = client.decrypt(o)
|
155 |
-
encrypted_path.append(o.serialize())
|
156 |
-
path.append(origin)
|
157 |
-
|
158 |
st.session_state['encrypted_shortest_path'] = encrypted_path
|
159 |
-
|
|
|
|
|
|
|
160 |
st.rerun()
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
165 |
-
destination = st.session_state['destination']
|
166 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
167 |
-
|
168 |
-
origin_node = st.session_state['origin_node']
|
169 |
-
destination_node = st.session_state['destination_node']
|
170 |
with c1:
|
171 |
-
st.write(f"Selected origin is node number: {origin_node}")
|
172 |
-
st.write(f"Selected destination is node number: {destination_node}")
|
173 |
-
st.write("Received path is:")
|
174 |
-
st.write(f"{display_encrypted(st.session_state['encrypted_shortest_path'][0])}")
|
175 |
-
st.write("...")
|
176 |
-
st.write(f"{display_encrypted(st.session_state['encrypted_shortest_path'][-1])}")
|
177 |
-
|
178 |
if st.button('Decrypt and show shortest path'):
|
179 |
with st.spinner('Computing'):
|
180 |
-
client
|
181 |
-
path = []
|
182 |
-
for enc_value in st.session_state['encrypted_shortest_path']:
|
183 |
-
deserialized_result = fhe.Value.deserialize(enc_value)
|
184 |
-
path.append(client.decrypt(deserialized_result))
|
185 |
-
st.session_state['decrypted_result'] = path
|
186 |
st.rerun()
|
187 |
with c2:
|
|
|
188 |
|
189 |
-
|
190 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
191 |
-
|
192 |
-
destination = st.session_state['destination']
|
193 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
194 |
-
st_data_final = st_folium(m, width=725, key="destination", returned_objects=[])
|
195 |
-
with c3:
|
196 |
-
st.write("")
|
197 |
-
st.write("")
|
198 |
-
st.write(f"Received origin: {display_encrypted(st.session_state['encrypted_origin'])}")
|
199 |
-
st.write(f"Received destination: {display_encrypted(st.session_state['encrypted_destination'])}")
|
200 |
-
# st.write(f"Next node is {display_encrypted(st.session_state['encrypted_shortest_path'])} and is sent to client")
|
201 |
if 'decrypted_result' in st.session_state :
|
202 |
-
origin = st.session_state['origin']
|
203 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
204 |
-
destination = st.session_state['destination']
|
205 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
206 |
-
|
207 |
-
origin_node = st.session_state['origin_node']
|
208 |
-
destination_node = st.session_state['destination_node']
|
209 |
with c1:
|
210 |
-
st.write(f"Selected origin is node number: {origin_node}")
|
211 |
-
st.write(f"Selected destination is node number: {destination_node}")
|
212 |
-
st.write("Received path is:")
|
213 |
-
st.write(f"{display_encrypted(st.session_state['encrypted_shortest_path'][0])}")
|
214 |
-
st.write("...")
|
215 |
-
st.write(f"{display_encrypted(st.session_state['encrypted_shortest_path'][-1])}")
|
216 |
st.write(f"Decrypted result is: {st.session_state['decrypted_result']}")
|
217 |
-
|
218 |
-
for key in st.session_state.keys():
|
219 |
-
del st.session_state[key]
|
220 |
-
st.rerun()
|
221 |
with c2:
|
222 |
-
origin = st.session_state['origin']
|
223 |
-
folium.Marker([origin.y, origin.x], popup="Origin", tooltip="Origin").add_to(m)
|
224 |
-
destination = st.session_state['destination']
|
225 |
-
folium.Marker([destination.y, destination.x], popup="Destination", tooltip="Destination").add_to(m)
|
226 |
shortest_path_list = st.session_state['decrypted_result']
|
227 |
final_result = generate_path(shortest_path_list, rel)
|
228 |
-
final_result
|
229 |
-
m=m,
|
230 |
-
color="green",
|
231 |
-
style_kwds = {"weight":5},
|
232 |
-
tooltip="name",
|
233 |
-
popup=["name"],
|
234 |
-
name="Quadratic-Paris",
|
235 |
-
)
|
236 |
-
st_data_final = st_folium(m, width=725, key="destination", returned_objects=[])
|
237 |
-
with c3:
|
238 |
-
st.write("")
|
239 |
-
st.write("")
|
240 |
-
st.write(f"Received origin: {display_encrypted(st.session_state['encrypted_origin'])}")
|
241 |
-
st.write(f"Received destination: {display_encrypted(st.session_state['encrypted_destination'])}")
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from shapely.geometry import Point
|
3 |
+
import geopandas
|
4 |
import config
|
5 |
+
from utils import generate_path, set_up_server, set_up_client, display_encrypted, add_marker,\
|
6 |
+
display_map, init_session, add_to_server_side, add_to_client_side, display_client_side,\
|
7 |
+
display_server_side, restart_session, compute_shortest_path, decrypt_shortest_path
|
8 |
from network import get_frames
|
|
|
|
|
9 |
|
|
|
10 |
ways = geopandas.read_file(config.roads_filepath)
|
11 |
nodes, _, rel = get_frames(ways)
|
12 |
|
13 |
server = set_up_server()
|
14 |
client = set_up_client(server.client_specs.serialize())
|
15 |
|
16 |
+
c1, c2, c3 = init_session()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
with c1:
|
19 |
+
display_client_side()
|
20 |
|
21 |
with c3:
|
22 |
+
display_server_side()
|
|
|
23 |
|
24 |
+
# keys generation view
|
25 |
if 'evaluation_key' not in st.session_state:
|
26 |
with c1:
|
27 |
if st.button('Generate keys'):
|
28 |
with st.spinner('Generating keys'):
|
29 |
client.keys.load_if_exists_generate_and_save_otherwise(config.keys_filepath)
|
|
|
30 |
st.session_state['evaluation_key'] = client.evaluation_keys.serialize()
|
31 |
+
add_to_client_side("Encryption/decryption keys and evaluation keys are generated.")
|
32 |
+
add_to_client_side("The evaluation key is sent to the server.")
|
33 |
+
add_to_server_side(f"Evaluation key: {display_encrypted(st.session_state['evaluation_key'])}")
|
34 |
st.rerun()
|
35 |
else:
|
36 |
+
# Origin selection view
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if 'origin' not in st.session_state :
|
38 |
with c1:
|
39 |
st.write("Select the origin on the map")
|
|
|
40 |
with c2:
|
41 |
+
st_data_origin = display_map(nodes, returned_objects=["last_object_clicked"])
|
|
|
|
|
|
|
42 |
|
43 |
if st_data_origin["last_object_clicked"]:
|
44 |
origin = Point(st_data_origin["last_object_clicked"]["lng"], st_data_origin["last_object_clicked"]["lat"])
|
45 |
+
add_marker(origin, 'origin')
|
46 |
origin_node = nodes[nodes['geometry'] == origin]['node_id'].values[0]
|
47 |
st.session_state['origin'] = origin
|
48 |
st.session_state['origin_node'] = origin_node
|
49 |
+
add_to_client_side(f"Selected origin is node number: {origin_node}")
|
50 |
st.rerun()
|
51 |
|
52 |
+
# Destination selection view
|
53 |
if 'origin' in st.session_state and 'destination' not in st.session_state:
|
|
|
|
|
54 |
origin_node = st.session_state['origin_node']
|
55 |
with c1:
|
|
|
56 |
st.write("Select the destination on the map")
|
57 |
with c2:
|
58 |
+
st_data_destination = display_map(nodes, returned_objects=["last_object_clicked"])
|
|
|
|
|
|
|
59 |
|
60 |
if st_data_destination["last_object_clicked"]:
|
61 |
destination = Point(st_data_destination["last_object_clicked"]["lng"], st_data_destination["last_object_clicked"]["lat"])
|
62 |
+
add_marker(destination, 'destination')
|
63 |
+
destination_node = nodes[nodes['geometry'] == destination]['node_id'].values[0]
|
64 |
st.session_state['destination'] = destination
|
65 |
+
st.session_state['destination_node'] = destination_node
|
66 |
+
add_to_client_side(f"Selected destination is node number: {destination_node}")
|
67 |
st.rerun()
|
68 |
|
69 |
+
# Origin/Destination encryption view
|
70 |
+
if 'destination' in st.session_state and 'encrypted_origin' not in st.session_state :
|
|
|
|
|
|
|
|
|
71 |
origin_node = st.session_state['origin_node']
|
72 |
destination_node = st.session_state['destination_node']
|
73 |
with c1:
|
|
|
|
|
74 |
if st.button('Encrypt and send inputs'):
|
75 |
with st.spinner('Encrypting inputs'):
|
76 |
client.keys.load(config.keys_filepath)
|
77 |
origin, destination = client.encrypt(origin_node, destination_node)
|
78 |
st.session_state['encrypted_origin'] = origin.serialize()
|
79 |
st.session_state['encrypted_destination'] = destination.serialize()
|
80 |
+
add_to_server_side(f"Received origin: {display_encrypted(st.session_state['encrypted_origin'])}")
|
81 |
+
add_to_server_side(f"Received destination: {display_encrypted(st.session_state['encrypted_destination'])}")
|
82 |
st.rerun()
|
83 |
with c2:
|
84 |
+
st_data_final = display_map(nodes)
|
|
|
|
|
|
|
85 |
|
86 |
+
# Shortest path computation view
|
87 |
if 'encrypted_origin' in st.session_state and 'encrypted_shortest_path' not in st.session_state:
|
|
|
|
|
|
|
|
|
|
|
88 |
origin_node = st.session_state['origin_node']
|
89 |
destination_node = st.session_state['destination_node']
|
|
|
|
|
|
|
90 |
with c2:
|
91 |
+
st_data_final = display_map(nodes)
|
|
|
|
|
|
|
|
|
92 |
with c3:
|
93 |
+
if st.button('Compute and send shortest path'):
|
|
|
|
|
|
|
|
|
94 |
with st.spinner('Computing'):
|
95 |
+
encrypted_path = compute_shortest_path(nodes.shape[0], client, server)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
st.session_state['encrypted_shortest_path'] = encrypted_path
|
97 |
+
add_to_client_side("Received path is:")
|
98 |
+
add_to_client_side(f"{display_encrypted(st.session_state['encrypted_shortest_path'][0])}")
|
99 |
+
add_to_client_side("...")
|
100 |
+
add_to_client_side(f"{display_encrypted(st.session_state['encrypted_shortest_path'][-1])}")
|
101 |
st.rerun()
|
102 |
+
|
103 |
+
# Result decryption view
|
104 |
+
if 'encrypted_shortest_path' in st.session_state and 'decrypted_result' not in st.session_state :
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
with c1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
if st.button('Decrypt and show shortest path'):
|
107 |
with st.spinner('Computing'):
|
108 |
+
decrypt_shortest_path(client)
|
|
|
|
|
|
|
|
|
|
|
109 |
st.rerun()
|
110 |
with c2:
|
111 |
+
st_data_final = display_map(nodes)
|
112 |
|
113 |
+
# Display result view
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
if 'decrypted_result' in st.session_state :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
with c1:
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
st.write(f"Decrypted result is: {st.session_state['decrypted_result']}")
|
117 |
+
restart_session()
|
|
|
|
|
|
|
118 |
with c2:
|
|
|
|
|
|
|
|
|
119 |
shortest_path_list = st.session_state['decrypted_result']
|
120 |
final_result = generate_path(shortest_path_list, rel)
|
121 |
+
display_map(nodes, path=final_result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
network.py
CHANGED
@@ -4,6 +4,17 @@ import numpy
|
|
4 |
|
5 |
|
6 |
def get_frames(ways: geopandas.GeoDataFrame):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
edges = ways.explode(index_parts=False)[["id", "name", "geometry"]]
|
8 |
edges.index.name = "way_id"
|
9 |
edges = edges.reset_index()
|
@@ -15,6 +26,14 @@ def get_frames(ways: geopandas.GeoDataFrame):
|
|
15 |
|
16 |
|
17 |
def get(ways: geopandas.GeoDataFrame) -> networkx.Graph:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
nodes, edges, rel = get_frames(ways)
|
19 |
graph = networkx.Graph()
|
20 |
for node in rel['node_id'].unique():
|
@@ -22,7 +41,6 @@ def get(ways: geopandas.GeoDataFrame) -> networkx.Graph:
|
|
22 |
for idx in rel['way_id'].unique():
|
23 |
node_id1, node_id2 = rel.loc[idx]['node_id']
|
24 |
way = edges.loc[idx]
|
25 |
-
way_length = rel.loc[idx]['geometry'].iloc[0].length
|
26 |
graph.add_edge(
|
27 |
node_id1, node_id2,
|
28 |
weight=way.geometry.length,
|
@@ -33,6 +51,14 @@ def get(ways: geopandas.GeoDataFrame) -> networkx.Graph:
|
|
33 |
|
34 |
|
35 |
def weighted_adjacency_matrix(graph: networkx.Graph) -> numpy.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
weights = numpy.full((graph.number_of_nodes(),)*2, numpy.inf)
|
37 |
numpy.fill_diagonal(weights, 0., wrap=False)
|
38 |
for i in graph:
|
|
|
4 |
|
5 |
|
6 |
def get_frames(ways: geopandas.GeoDataFrame):
|
7 |
+
"""Extract nodes, edges, and relations from a GeoDataFrame of ways.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
ways (geopandas.GeoDataFrame): A GeoDataFrame containing information about ways.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
nodes (geopandas.GeoDataFrame): A GeoDataFrame containing node geometries.
|
14 |
+
edges (geopandas.GeoDataFrame): A GeoDataFrame containing ways information.
|
15 |
+
rel (geopandas.GeoDataFrame): A GeoDataFrame representing relations between nodes and edges.
|
16 |
+
|
17 |
+
"""
|
18 |
edges = ways.explode(index_parts=False)[["id", "name", "geometry"]]
|
19 |
edges.index.name = "way_id"
|
20 |
edges = edges.reset_index()
|
|
|
26 |
|
27 |
|
28 |
def get(ways: geopandas.GeoDataFrame) -> networkx.Graph:
|
29 |
+
""" Convert a GeoDataFrame of ways into a networkx Graph.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
ways (geopandas.GeoDataFrame): A GeoDataFrame containing information about ways.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
networkx.Graph: A networkx Graph representing the geographical network.
|
36 |
+
"""
|
37 |
nodes, edges, rel = get_frames(ways)
|
38 |
graph = networkx.Graph()
|
39 |
for node in rel['node_id'].unique():
|
|
|
41 |
for idx in rel['way_id'].unique():
|
42 |
node_id1, node_id2 = rel.loc[idx]['node_id']
|
43 |
way = edges.loc[idx]
|
|
|
44 |
graph.add_edge(
|
45 |
node_id1, node_id2,
|
46 |
weight=way.geometry.length,
|
|
|
51 |
|
52 |
|
53 |
def weighted_adjacency_matrix(graph: networkx.Graph) -> numpy.ndarray:
|
54 |
+
"""Generate the weighted adjacency matrix from a networkx Graph.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
graph (networkx.Graph): A networkx Graph representing the network.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
numpy.ndarray: A NumPy array representing the weighted adjacency matrix.
|
61 |
+
"""
|
62 |
weights = numpy.full((graph.number_of_nodes(),)*2, numpy.inf)
|
63 |
numpy.fill_diagonal(weights, 0., wrap=False)
|
64 |
for i in graph:
|
utils.py
CHANGED
@@ -1,47 +1,220 @@
|
|
|
|
1 |
import pandas as pd
|
2 |
from concrete import fhe
|
3 |
-
from config import circuit_filepath
|
4 |
-
|
|
|
|
|
5 |
|
6 |
def set_up_server():
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
try:
|
9 |
server = fhe.Server.load(circuit_filepath)
|
10 |
-
except
|
11 |
-
raise
|
|
|
12 |
|
13 |
return server
|
14 |
|
15 |
def set_up_client(serialized_client_specs):
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
client_specs = fhe.ClientSpecs.deserialize(serialized_client_specs)
|
18 |
client = fhe.Client(client_specs)
|
19 |
|
20 |
return client
|
21 |
|
22 |
def display_encrypted(encrypted_object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
encoded_text = encrypted_object.hex()
|
24 |
res = '...' + encoded_text[-10:]
|
25 |
return res
|
26 |
|
27 |
-
def
|
28 |
-
path
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
if origin == destination:
|
33 |
break
|
34 |
-
o =
|
35 |
-
origin =
|
36 |
-
|
37 |
-
return
|
38 |
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
pairs_list = []
|
41 |
-
for i in range(len(
|
42 |
-
current_element =
|
43 |
-
next_element =
|
44 |
-
result =
|
45 |
pairs_list.append(result)
|
46 |
final_result = pd.concat(pairs_list)
|
47 |
return final_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contains a set of useful functions
|
2 |
import pandas as pd
|
3 |
from concrete import fhe
|
4 |
+
from config import circuit_filepath, keys_filepath
|
5 |
+
import streamlit as st
|
6 |
+
import folium
|
7 |
+
from streamlit_folium import st_folium
|
8 |
|
9 |
def set_up_server():
|
10 |
+
"""Load a server instance from a specified circuit file
|
11 |
+
|
12 |
+
Raises:
|
13 |
+
OSError: If there is an issue loading the FHE server.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
concrete.fhe.compilation.server.Server: A server instance loaded from the circuit file.
|
17 |
+
"""
|
18 |
try:
|
19 |
server = fhe.Server.load(circuit_filepath)
|
20 |
+
except OSError as e:
|
21 |
+
raise OSError(f"Something went wrong with the circuit. Make sure that the circuit \
|
22 |
+
exists in {circuit_filepath}.If not run python generate_circuit.py.") from e
|
23 |
|
24 |
return server
|
25 |
|
26 |
def set_up_client(serialized_client_specs):
|
27 |
+
"""Generate a client instance from a specified circuit file
|
28 |
+
|
29 |
+
Args:
|
30 |
+
serialized_client_specs (bytes): A serialized client specs
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
concrete.fhe.compilation.client.Client: A client instance created from the client specs
|
34 |
+
"""
|
35 |
+
|
36 |
client_specs = fhe.ClientSpecs.deserialize(serialized_client_specs)
|
37 |
client = fhe.Client(client_specs)
|
38 |
|
39 |
return client
|
40 |
|
41 |
def display_encrypted(encrypted_object):
|
42 |
+
"""Display a truncated representation of an encrypted object as a hexadecimal string
|
43 |
+
|
44 |
+
Args:
|
45 |
+
encrypted_object (bytes): A serialized encrypted object to display
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
str: A truncated hexadecimal representation of the encrypted object
|
49 |
+
"""
|
50 |
encoded_text = encrypted_object.hex()
|
51 |
res = '...' + encoded_text[-10:]
|
52 |
return res
|
53 |
|
54 |
+
def compute_shortest_path(nodes_nb, client, server):
|
55 |
+
"""Calculate the shortest path between two nodes
|
56 |
+
|
57 |
+
Args:
|
58 |
+
nodes_nb (int): The number of nodes in the network
|
59 |
+
client (concrete.fhe.compilation.client.Client): A client instance
|
60 |
+
server (concrete.fhe.compilation.server.Server): A server instance
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
List[bytes]: A list of encrypted values representing the path
|
64 |
+
"""
|
65 |
+
deserialized_origin = fhe.Value.deserialize(st.session_state['encrypted_origin'])
|
66 |
+
deserialized_destination = fhe.Value.deserialize(st.session_state['encrypted_destination'])
|
67 |
+
deserialized_evaluation_keys = fhe.EvaluationKeys.deserialize(st.session_state['evaluation_key'])
|
68 |
+
client.keys.load_if_exists_generate_and_save_otherwise(keys_filepath)
|
69 |
+
origin = st.session_state['origin_node']
|
70 |
+
destination = st.session_state['destination_node']
|
71 |
+
encrypted_path = [st.session_state['encrypted_origin'], ]
|
72 |
+
o, d = deserialized_origin, deserialized_destination
|
73 |
+
for _ in range(nodes_nb):
|
74 |
if origin == destination:
|
75 |
break
|
76 |
+
o = server.run(o, d, evaluation_keys=deserialized_evaluation_keys)
|
77 |
+
origin = client.decrypt(o)
|
78 |
+
encrypted_path.append(o.serialize())
|
79 |
+
return encrypted_path
|
80 |
|
81 |
+
|
82 |
+
def generate_path(shortest_path, roads):
|
83 |
+
"""Generate a path from a list of nodes using road data.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
shortest_path (List[int]): A list of nodes representing the shortest path.
|
87 |
+
roads (geopandas.DataFrame): A DataFrame containing the ways between the nodes.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
pandas.DataFrame: A DataFrame representing the road segments that form the shortest path.
|
91 |
+
"""
|
92 |
pairs_list = []
|
93 |
+
for i in range(len(shortest_path) - 1):
|
94 |
+
current_element = shortest_path[i]
|
95 |
+
next_element = shortest_path[i + 1]
|
96 |
+
result = roads.groupby('way_id').filter(lambda x: set([current_element, next_element]).issubset(x['node_id']))
|
97 |
pairs_list.append(result)
|
98 |
final_result = pd.concat(pairs_list)
|
99 |
return final_result
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
def decrypt_shortest_path(client):
|
105 |
+
"""Decrypt and store the shortest path in the session state
|
106 |
+
|
107 |
+
Args:
|
108 |
+
client (concrete.fhe.compilation.client.Client): A client instance
|
109 |
+
"""
|
110 |
+
client.keys.load_if_exists_generate_and_save_otherwise(keys_filepath)
|
111 |
+
path = []
|
112 |
+
for enc_value in st.session_state['encrypted_shortest_path']:
|
113 |
+
deserialized_result = fhe.Value.deserialize(enc_value)
|
114 |
+
path.append(client.decrypt(deserialized_result))
|
115 |
+
st.session_state['decrypted_result'] = path
|
116 |
+
|
117 |
+
def init_session():
|
118 |
+
"""Initialize the Streamlit session and layout configuration.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Streamlit.columns: A tuple of Streamlit columns for layout customization.
|
122 |
+
"""
|
123 |
+
st.set_page_config(layout="wide")
|
124 |
+
|
125 |
+
if 'markers' not in st.session_state:
|
126 |
+
st.session_state['markers'] = []
|
127 |
+
if 'server_side' not in st.session_state:
|
128 |
+
st.session_state['server_side'] = []
|
129 |
+
if 'client_side' not in st.session_state:
|
130 |
+
st.session_state['client_side'] = []
|
131 |
+
|
132 |
+
c1, c2, c3 = st.columns([1, 3, 1])
|
133 |
+
|
134 |
+
return c1, c2, c3
|
135 |
+
|
136 |
+
def add_marker(coordinates, name):
|
137 |
+
"""Add a marker with coordinates and a name to the Streamlit session.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
coordinates (Point): The coordinates of the marker
|
141 |
+
name (str): The name or label for the marker
|
142 |
+
"""
|
143 |
+
data = {'coordinates': coordinates, 'name': name}
|
144 |
+
st.session_state['markers'].append(data)
|
145 |
+
|
146 |
+
|
147 |
+
def display_map(nodes, returned_objects=None, path=None):
|
148 |
+
"""Display the map with nodes and optional markers and paths.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
nodes (geopandas.DataFrame): A dataframe containing the nodes to display
|
152 |
+
returned_objects (List[str], optional): Objects to be returned when interacting with the map. Defaults to None.
|
153 |
+
path (pandas.DataFrame, optional): A DataFrame representing the road segments that form the shortest path. Defaults to None.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Streamlit.FoliumMap: An interactive map displaying nodes, markers, and paths
|
157 |
+
"""
|
158 |
+
m = nodes.explore(
|
159 |
+
color="red",
|
160 |
+
marker_kwds=dict(radius=5, fill=True, name='node_id'),
|
161 |
+
tooltip="node_id",
|
162 |
+
tooltip_kwds=dict(labels=False),
|
163 |
+
zoom_control=False,
|
164 |
+
)
|
165 |
+
|
166 |
+
if 'markers' in st.session_state:
|
167 |
+
for mrk in st.session_state['markers']:
|
168 |
+
folium.Marker([mrk['coordinates'].y, mrk['coordinates'].x], popup=mrk['name'], tooltip=mrk['name']).add_to(m)
|
169 |
+
|
170 |
+
if path is not None:
|
171 |
+
path.explore(
|
172 |
+
m=m,
|
173 |
+
color="green",
|
174 |
+
style_kwds = {"weight":5},
|
175 |
+
tooltip="name",
|
176 |
+
popup=["name"],
|
177 |
+
name="Quadratic-Paris",
|
178 |
+
)
|
179 |
+
|
180 |
+
return st_folium(m, width=725, key="origin", returned_objects=returned_objects)
|
181 |
+
|
182 |
+
|
183 |
+
def add_to_server_side(message):
|
184 |
+
"""Add a message to the server side of the view
|
185 |
+
|
186 |
+
Args:
|
187 |
+
message (str): The message to be added to the server side
|
188 |
+
"""
|
189 |
+
st.session_state['server_side'].append(message)
|
190 |
+
|
191 |
+
def add_to_client_side(message):
|
192 |
+
"""Add a message to the client side of the view
|
193 |
+
|
194 |
+
Args:
|
195 |
+
message (str): The message to be added to the client side
|
196 |
+
"""
|
197 |
+
st.session_state['client_side'].append(message)
|
198 |
+
|
199 |
+
def display_server_side():
|
200 |
+
"""Display the messages stored in the server-side view.
|
201 |
+
"""
|
202 |
+
st.write("**Server-side**")
|
203 |
+
for message in st.session_state['server_side']:
|
204 |
+
st.write(message)
|
205 |
+
|
206 |
+
def display_client_side():
|
207 |
+
"""Display the messages stored in the client-side view.
|
208 |
+
"""
|
209 |
+
st.write("**Client-side**")
|
210 |
+
for message in st.session_state['client_side']:
|
211 |
+
st.write(message)
|
212 |
+
|
213 |
+
def restart_session():
|
214 |
+
"""Clear the session state to restart
|
215 |
+
"""
|
216 |
+
if st.button('Restart'):
|
217 |
+
for key in st.session_state.items():
|
218 |
+
if key != 'evaluation_key':
|
219 |
+
del st.session_state[key]
|
220 |
+
st.rerun()
|