Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- Dockerfile +12 -0
- app.py +389 -0
- requirements.txt +8 -0
Dockerfile
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10

WORKDIR /app

# Install dependencies first so this layer is cached across code-only changes.
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

COPY . .
# NOTE(review): everything below is already copied by `COPY . .`; this line
# only makes the build fail early when one of the data files is missing from
# the build context. Confirm it is still wanted.
COPY all_embeddings.p all_embeddings_with_splits.p app.py embeddings.p pacmap2d.npy pacmap3d.npy pca.npy tsne2d.npy tsne3d.npy umap2d.npy umap3d.npy ./

# Hugging Face Docker Spaces route traffic to port 7860 unless the Space
# README overrides `app_port`; the previous value 7680 looks like a digit
# transposition. NOTE(review): confirm against the Space's README metadata.
EXPOSE 7860
CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "7860"]
|
12 |
+
|
app.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dash
|
2 |
+
import dash_bootstrap_components as dbc
|
3 |
+
from dash import dcc
|
4 |
+
from dash import html
|
5 |
+
from dash.dependencies import Input, Output, State
|
6 |
+
|
7 |
+
from typing import List, Tuple
|
8 |
+
from scipy.spatial.distance import cdist
|
9 |
+
|
10 |
+
import pandas as pd
|
11 |
+
import numpy as np
|
12 |
+
import plotly.express as px
|
13 |
+
import plotly.graph_objects as go
|
14 |
+
|
15 |
+
|
16 |
+
# Table backing every plot. The callbacks in this file read its "pdb_id",
# "classification" and "splits" columns.
# NOTE(review): read_pickle executes arbitrary code from the file; only ever
# ship a pickle that was produced by a trusted pipeline.
df = pd.read_pickle('all_embeddings_with_splits.p')

app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

# Page layout: a narrow control column (dropdowns, "top n" input, update
# button, closest-points list) next to a wide column holding the graph.
app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.Label('Algorithm:'),
                                dcc.Dropdown(
                                    id="algorithm-dropdown",
                                    options=[
                                        {"label": "PCA", "value": "pca"},
                                        {"label": "UMAP", "value": "umap"},
                                        {"label": "tSNE", "value": "tsne"},
                                        {"label": "PaCMAP", "value": "pacmap"},
                                    ],
                                    value="pacmap",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Number of dimensions:'),
                                dcc.Dropdown(
                                    id="num-components-dropdown",
                                    options=[
                                        {"label": "2", "value": 2},
                                        {"label": "3", "value": 3}
                                    ],
                                    value=3,
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Color by:'),
                                dcc.Dropdown(
                                    id="color-by",
                                    options=[
                                        {
                                            "label": "Protein Classification",
                                            "value": "classification"
                                        },
                                        {
                                            "label": "Split (train/test/val/gpcr)",
                                            "value": "split"
                                        }
                                    ],
                                    value="classification",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Span(
                                    [
                                        "Keep the top ",
                                        dcc.Input(
                                            id="top-n-classes",
                                            type="number",
                                            value=10,
                                            min=1,
                                            # Upper bound: number of distinct
                                            # classification values in df.
                                            max=len(df["classification"].unique()),
                                            step=1,
                                            style={"width": "50px"}
                                        ),
                                        " classes."
                                    ],
                                    style={"margin-bottom": "20px"}
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Update",
                                    id="update-button",
                                    color="primary",
                                    n_clicks=0,
                                    style={"width": "100%", "margin": "10px 0px"}
                                ),
                                # Filled by the clickData callback with one
                                # card per neighbouring point.
                                dbc.Container(
                                    id="closest-points",
                                    style={"max-height": "65vh", "overflow-y": "auto"}
                                ),
                            ],
                            width={"size": 2, "order": 1},
                        ),
                        dbc.Col(
                            dcc.Graph(
                                id="embedding-graph",
                                style={"height": "100%", "width": "100%"},
                            ),
                            width={"size": 10, "order": 2},
                        ),
                    ],
                    style={"height": "95vh"}
                )
            ],
            # Was "100hv", which is not a CSS unit and was silently ignored.
            style={"height": "100vh"}
        ),
        html.Hr(),
    ],
    fluid=True,
)
|
121 |
+
|
122 |
+
def load_embedding(algorithm: str, num_components: int) -> np.ndarray:
    """Load a precomputed low-dimensional embedding from the working directory.

    Parameters
    ----------
    algorithm : str
        One of ``"pca"``, ``"umap"``, ``"tsne"`` or ``"pacmap"``.
    num_components : int
        Number of dimensions of the requested embedding. It selects the file
        ``"<algorithm><num_components>d.npy"``; it is ignored for ``"pca"``,
        which always reads ``"pca.npy"``.

    Returns
    -------
    np.ndarray
        The array stored in the selected ``.npy`` file.
        NOTE(review): callers index columns 0..2, so ``pca.npy`` presumably
        holds at least three components -- confirm.

    Raises
    ------
    ValueError
        If ``algorithm`` is not a known name, or ``num_components`` cannot be
        converted to an integer.
    FileNotFoundError
        If the selected file does not exist.
    """
    known_algorithms = ("pca", "umap", "tsne", "pacmap")
    # Both arguments arrive from browser-controlled callback state, so they
    # are validated before being spliced into a file name.
    if algorithm not in known_algorithms:
        raise ValueError(
            f"Unknown algorithm {algorithm!r}; expected one of {known_algorithms}."
        )

    if algorithm == "pca":
        return np.load("pca.npy")

    # int() rejects anything that is not a plain number (e.g. path fragments).
    return np.load(f"{algorithm}{int(num_components)}d.npy")
|
142 |
+
|
143 |
+
def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of ``df["classification"]``.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``"classification"`` column.
    n : int
        How many classes to keep.

    Returns
    -------
    List[str]
        Class names ordered from most to least frequent.
    """
    frequency = df["classification"].value_counts()
    most_frequent = frequency.nlargest(n)
    return most_frequent.index.tolist()
|
145 |
+
|
146 |
+
def _color_labels_and_map(color_by: str, top_n_classes: int):
    """Return the per-row legend labels and a label -> colour mapping.

    The labels are built on a copy, so the module-level ``df`` is never
    modified by a request.
    """
    if color_by == "split":
        split_colors = {
            "gpcr": "red",
            "train": "blue",
            "val": "green",
            "test": "orange",
            "unknown": "grey",
        }
        return df["splits"].copy(), split_colors

    top_classes = get_top_n_classifications(df, n=int(top_n_classes))
    labels = df["classification"].copy()
    labels[~labels.isin(top_classes)] = "other"

    # The qualitative palette only has a handful of colours; wrap around
    # instead of raising IndexError when more classes are requested.
    palette = px.colors.qualitative.Plotly
    color_map = {
        name: palette[position % len(palette)]
        for position, name in enumerate(top_classes)
    }
    color_map.setdefault("other", "grey")
    return labels, color_map


def _scatter_trace(points, rows, label, color, three_d):
    """Build one marker trace (2D or 3D) for the rows sharing ``label``."""
    shared = dict(
        mode='markers',
        name=str(label),
        marker=dict(
            size=2.5 if three_d else 7.5,
            color=color,
            # Grey is the "everything else" colour: fade it into the background.
            opacity=1 if color != 'grey' else 0.3,
        ),
        hovertemplate=(
            "<b>PDB ID</b>: %{customdata[0]}<br>"
            "<b>Classification</b>: %{customdata[1]}<br>"
            "<extra></extra>"
        ),
        customdata=rows[['pdb_id', 'classification']].to_numpy(),
    )
    if three_d:
        return go.Scatter3d(
            x=points[:, 0], y=points[:, 1], z=points[:, 2], **shared
        )
    return go.Scatter(x=points[:, 0], y=points[:, 1], **shared)


@app.callback(
    Output("embedding-graph", "figure"),
    [
        Input("update-button", "n_clicks"),
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
        State("top-n-classes", "value"),
        State("color-by", "value"),
    ]
)
def update_embedding_graph(n_clicks: int,
                           algorithm: str,
                           num_components: int,
                           top_n_classes: int,
                           color_by: str) -> go.Figure:
    """Redraw the embedding scatter plot when the update button is clicked.

    Parameters
    ----------
    n_clicks : int
        Click count of the update button; nothing is drawn while it is 0.
    algorithm : str
        Embedding algorithm, forwarded to ``load_embedding``.
    num_components : int
        2 for a flat scatter plot, 3 for a 3D scatter plot.
    top_n_classes : int
        Number of most frequent classifications that get their own colour;
        the rest are grouped as "other". Unused when colouring by split.
    color_by : str
        ``"split"`` to colour by the "splits" column, anything else to
        colour by classification.

    Returns
    -------
    go.Figure
        One trace per legend label.

    Raises
    ------
    dash.exceptions.PreventUpdate
        Before the first click, or when the "top n" input is empty.
    ValueError
        If ``num_components`` is neither 2 nor 3.
    """
    if not n_clicks:
        raise dash.exceptions.PreventUpdate
    if color_by != "split" and top_n_classes is None:
        # The number input reports None while it is empty or out of range.
        raise dash.exceptions.PreventUpdate
    if num_components not in (2, 3):
        raise ValueError(f"num_components must be 2 or 3, got {num_components!r}")

    three_d = num_components == 3
    embedding = load_embedding(algorithm, num_components)
    labels, color_map = _color_labels_and_map(color_by, top_n_classes)

    fig = go.Figure()
    for label in labels.unique():
        positions = np.flatnonzero((labels == label).to_numpy())
        if positions.size == 0:
            # e.g. a missing label, which compares unequal to itself.
            continue
        fig.add_trace(
            _scatter_trace(
                embedding[positions],
                df.iloc[positions],
                label,
                # Labels without an assigned colour are drawn grey.
                color_map.get(label, 'grey'),
                three_d,
            )
        )

    if three_d:
        fig.update_layout(
            scene=dict(
                xaxis=dict(showgrid=False, showticklabels=False, title=""),
                yaxis=dict(showgrid=False, showticklabels=False, title=""),
                zaxis=dict(showgrid=False, showticklabels=False, title=""),
            ),
        )
        fig.update_scenes(
            xaxis_visible=False, yaxis_visible=False, zaxis_visible=False
        )
    else:
        # update_scenes only affects 3D scenes; cartesian axes are hidden
        # through update_xaxes / update_yaxes.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
|
265 |
+
|
266 |
+
#### GET CLOSEST POINTS
|
267 |
+
|
268 |
+
def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Pull the PDB id and classification out of a graph ``clickData`` payload.

    Only the first entry of ``clickData["points"]`` is read, and from it only
    the first two items of its ``"customdata"`` list. Example payload::

        {
            "points": [
                {
                    "x": 11.330583,
                    "y": 15.741333,
                    "z": -5.3435574,
                    "curveNumber": 2,
                    "pointNumber": 982,
                    "customdata": [
                        "1zfp",
                        "complex (signal transduction/peptide)"
                    ]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Payload describing the clicked point, shaped as above.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` of the clicked point.
    """
    clicked_point = clickData["points"][0]
    custom = clicked_point["customdata"]
    return custom[0], custom[1]
|
312 |
+
|
313 |
+
def find_closest_n_points(df: pd.DataFrame,
                          embedding: np.array,
                          index: int = None,
                          pdb_id: str = None,
                          n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` rows whose embedding is closest to a query point.

    The query point is included in the result: its distance to itself is
    zero, so it normally comes first.

    Parameters
    ----------
    df : pd.DataFrame
        Table with "pdb_id" and "classification" columns, row-aligned with
        ``embedding`` (row ``i`` of ``df`` describes ``embedding[i]``).
    embedding : np.ndarray
        A 2D numpy array with the embedding coordinates, one row per point.
    index : int, optional
        Row position of the query point. Ignored when ``pdb_id`` is given.
    pdb_id : str, optional
        PDB id of the query point; takes precedence over ``index``. If the
        id occurs more than once, its first occurrence is used.
    n : int
        The number of closest points to retrieve.

    Returns
    -------
    Tuple[list, list]
        The PDB ids and the classifications of the closest points, ordered
        from nearest to farthest.

    Raises
    ------
    ValueError
        If ``pdb_id`` is not present in ``df``, or if neither ``index`` nor
        ``pdb_id`` is provided.
    """
    if pdb_id:
        # Resolve to a row *position*: df.index holds labels, which only
        # coincide with positions for a default RangeIndex.
        matches = np.flatnonzero((df["pdb_id"] == pdb_id).to_numpy())
        if matches.size == 0:
            raise ValueError(f"pdb_id {pdb_id!r} not found in df")
        index = int(matches[0])

    if index is None:
        raise ValueError("Either index or pdb_id must be provided.")

    query = embedding[index][np.newaxis, :]
    distances = cdist(query, embedding)[0]
    # A stable sort keeps equidistant points in row order.
    closest_positions = np.argsort(distances, kind="stable")[:n]

    nearest = df.iloc[closest_positions]
    return nearest["pdb_id"].tolist(), nearest["classification"].tolist()
|
345 |
+
|
346 |
+
|
347 |
+
@app.callback(
    Output("closest-points", "children"),
    [
        Input("embedding-graph", "clickData")
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
    ]
)
def update_closest_points_div(
    clickData: dict,
    algorithm: str,
    num_components: int) -> list:
    """Fill the "closest-points" container after a click on the graph.

    Parameters
    ----------
    clickData : dict
        Click payload of the embedding graph, or None before any click.
    algorithm : str
        Current value of the algorithm dropdown.
    num_components : int
        Current value of the dimensions dropdown.

    Returns
    -------
    list
        One card per neighbouring point (PDB id and classification), or a
        single hint element while nothing has been clicked.
    """
    if clickData is None:
        # The children must not re-use the id "closest-points": that id
        # already belongs to the container being filled.
        return [html.Div("Click on a data point to see the closest points.")]

    pdb_id, _ = extract_info_from_clickData(clickData)

    # NOTE(review): the dropdown state may have changed since the figure was
    # last drawn, in which case neighbours are computed on a different
    # embedding than the one on screen -- confirm whether that is acceptable.
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_ids_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(neighbour_id, className="card-title"),
                    html.P(neighbour_class, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for neighbour_id, neighbour_class in zip(
            closest_ids, closest_ids_classifications)
    ]
|
386 |
+
|
387 |
+
|
388 |
+
if __name__ == "__main__":
    # The container is started as `python app.py --host ... --port ...`;
    # without parsing those flags the server binds the Dash defaults and is
    # unreachable from outside the container.
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Serve the embedding plots Dash app.")
    # Defaults mirror Dash's own (HOST / PORT environment variables, then
    # 127.0.0.1:8050), so running without flags behaves as before.
    parser.add_argument("--host", default=os.getenv("HOST", "127.0.0.1"),
                        help="Interface to bind.")
    parser.add_argument("--port", type=int,
                        default=int(os.getenv("PORT", "8050")),
                        help="Port to listen on.")
    # Debug stays on by default for local runs; pass --no-debug when the
    # server is exposed publicly.
    parser.add_argument("--no-debug", dest="debug", action="store_false",
                        help="Disable Dash debug mode.")
    args = parser.parse_args()

    app.run_server(host=args.host, port=args.port, debug=args.debug)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dash==2.9.3
|
2 |
+
dash-bootstrap-components==1.4.1
|
3 |
+
dash-core-components==2.0.0
|
4 |
+
dash-html-components==2.0.0
|
5 |
+
plotly==5.14.1
|
6 |
+
numpy==1.23.5
|
7 |
+
pandas==1.5.0
|
8 |
+
scipy==1.10.0
|