# raulminan's picture
# Update app.py
# 6ef01d4
import dash
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
from typing import List, Tuple
from scipy.spatial.distance import cdist
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
# Metadata table shared by every callback in this file. The callbacks read its
# "classification", "splits" and "pdb_id" columns and index it positionally
# with row numbers taken from the embedding arrays, so its row order is
# expected to match the rows of the .npy embedding files.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load this file from a trusted source.
df = pd.read_pickle('all_embeddings_with_splits.p')
# Dash application object, styled with the Bootstrap theme.
app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])
# Page layout: a narrow control column on the left (dropdowns, "top N" input,
# Update button and the list of closest points) and the embedding plot filling
# the remaining width on the right.
app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [
                        # --- Left column: plot controls + closest points ---
                        dbc.Col(
                            [
                                html.Label('Algorithm:'),
                                dcc.Dropdown(
                                    id="algorithm-dropdown",
                                    options=[
                                        {"label": "PCA", "value": "pca"},
                                        {"label": "UMAP", "value": "umap"},
                                        {"label": "tSNE", "value": "tsne"},
                                        {"label": "PaCMAP", "value": "pacmap"},
                                    ],
                                    value="pacmap",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Number of dimensions:'),
                                dcc.Dropdown(
                                    id="num-components-dropdown",
                                    options=[
                                        {"label": "2", "value": 2},
                                        {"label": "3", "value": 3}
                                    ],
                                    value=3,
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Label('Color by:'),
                                dcc.Dropdown(
                                    id="color-by",
                                    options=[
                                        {
                                            "label": "Protein Classification",
                                            "value": "classification"
                                        },
                                        {
                                            "label": "Split (train/test/val/gpcr)",
                                            "value": "split"
                                        }
                                    ],
                                    value="classification",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"}
                                ),
                                html.Span(
                                    [
                                        "Keep the top ",
                                        dcc.Input(
                                            id="top-n-classes",
                                            type="number",
                                            value=10,
                                            min=1,
                                            # Upper bound: number of distinct
                                            # classifications in the data.
                                            max=len(df["classification"].unique()),
                                            step=1,
                                            style={"width": "50px"}
                                        ),
                                        " classes."
                                    ],
                                    style={"margin-bottom": "20px"}
                                ),
                                html.Br(),
                                # The graph is only redrawn when this button
                                # is clicked; the controls above are read as
                                # callback State.
                                dbc.Button(
                                    "Update",
                                    id="update-button",
                                    color="primary",
                                    n_clicks=0,
                                    style={"width": "100%", "margin": "10px 0px"}
                                ),
                                # Filled by the clickData callback with one
                                # card per neighbouring point.
                                dbc.Container(
                                    id="closest-points",
                                    style={"max-height": "65vh", "overflow-y": "auto"}
                                ),
                            ],
                            width={"size": 2, "order": 1},
                        ),
                        # --- Right column: the embedding scatter plot ---
                        dbc.Col(
                            dcc.Graph(
                                id="embedding-graph",
                                style={"height": "100%", "width": "100%"},
                            ),
                            width={"size": 10, "order": 2},
                        ),
                    ],
                    style={"height": "95vh"}
                )
            ],
            # Was "100hv", which is not a CSS unit and is ignored by browsers;
            # the viewport-height unit is "vh".
            style={"height": "100vh"}
        ),
        html.Hr(),
    ],
    fluid=True,
)
def load_embedding(algorithm: str, num_components: int) -> np.ndarray:
    """Load a precomputed embedding from a ``.npy`` file in the working directory.

    Parameters
    ----------
    algorithm : str
        Name of the dimensionality-reduction algorithm. ``"pca"`` always reads
        ``pca.npy``; any other value is used as the file-name prefix.
    num_components : int
        Number of dimensions of the requested embedding. It selects the file
        ``{algorithm}{num_components}d.npy`` and is ignored for ``"pca"``.

    Returns
    -------
    np.ndarray
        The array stored in the selected file, unchanged.

    Raises
    ------
    FileNotFoundError
        If the selected file does not exist (raised by ``np.load``).

    Notes
    -----
    NOTE(review): because a single ``pca.npy`` serves both the 2D and the 3D
    view, callers that use the whole array (e.g. for distance computations)
    see every stored PCA component, not just the first ``num_components`` --
    confirm this is intended.
    """
    if algorithm == "pca":
        return np.load("pca.npy")
    return np.load(f"{algorithm}{num_components}d.npy")
def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of the ``"classification"`` column.

    Parameters
    ----------
    df : pd.DataFrame
        Table with a ``"classification"`` column.
    n : int
        How many of the most frequent classifications to keep.

    Returns
    -------
    List[str]
        Classification labels ordered from most to least frequent.
    """
    frequencies = df["classification"].value_counts()
    most_frequent = frequencies.nlargest(n)
    return most_frequent.index.tolist()
# Hover text shared by the 2D and 3D traces; ``customdata`` carries
# (pdb_id, classification) for every plotted point.
_HOVER_TEMPLATE = (
    "<b>PDB ID</b>: %{customdata[0]}<br>"
    "<b>Classification</b>: %{customdata[1]}<br>"
    "<extra></extra>"
)


def _group_labels_and_colors(color_by: str,
                             top_n_classes: int) -> Tuple[pd.Series, dict]:
    """Return the per-row group label and the label -> color mapping.

    Works on a copy of the relevant column so the shared global ``df`` is
    never modified by a request.
    """
    if color_by == "split":
        color_map = {
            "gpcr": "red",
            "train": "blue",
            "val": "green",
            "test": "orange",
            "unknown": "grey",
        }
        return df["splits"].copy(), color_map

    top_classes = get_top_n_classifications(df, n=top_n_classes)
    color_series = df["classification"].copy()
    # Everything outside the top N is collapsed into one "other" group.
    color_series[~color_series.isin(top_classes)] = "other"
    # The qualitative palette has a fixed number of entries; cycle through it
    # so asking for more classes than colors does not raise IndexError.
    palette = px.colors.qualitative.Plotly
    color_map = {
        label: palette[i % len(palette)]
        for i, label in enumerate(top_classes)
    }
    color_map["other"] = "grey"
    return color_series, color_map


def _make_trace(points: np.ndarray, label, color: str, customdata, is_3d: bool):
    """Build one scatter trace (3D or 2D) for a single group of points."""
    shared = dict(
        x=points[:, 0],
        y=points[:, 1],
        mode='markers',
        name=str(label),
        marker=dict(
            # 2D markers are drawn larger than 3D ones.
            size=2.5 if is_3d else 7.5,
            color=color,
            # Grey groups are background context, so they are faded.
            opacity=1 if color != 'grey' else 0.3,
        ),
        hovertemplate=_HOVER_TEMPLATE,
        customdata=customdata,
    )
    if is_3d:
        return go.Scatter3d(z=points[:, 2], **shared)
    return go.Scatter(**shared)


@app.callback(
    Output("embedding-graph", "figure"),
    [
        Input("update-button", "n_clicks"),
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
        State("top-n-classes", "value"),
        State("color-by", "value"),
    ]
)
def update_embedding_graph(n_clicks: int,
                           algorithm: str,
                           num_components: int,
                           top_n_classes: int,
                           color_by: str) -> go.Figure:
    """Redraw the embedding scatter plot when the Update button is clicked.

    Parameters
    ----------
    n_clicks : int
        Click count of the Update button; nothing is drawn until it is > 0.
    algorithm : str
        Value of the algorithm dropdown, forwarded to ``load_embedding``.
    num_components : int
        2 for a flat scatter plot, 3 for a 3D scatter plot.
    top_n_classes : int
        Number of most frequent classifications that get their own color;
        only used when coloring by classification.
    color_by : str
        ``"split"`` colors by the ``"splits"`` column, anything else colors
        by the ``"classification"`` column.

    Returns
    -------
    go.Figure
        One trace per group, axes hidden, legend in the top-left corner.

    Raises
    ------
    dash.exceptions.PreventUpdate
        If the button has not been clicked yet, or the "top N" input is
        empty while coloring by classification.
    ValueError
        If ``num_components`` is neither 2 nor 3.
    """
    if n_clicks is None or n_clicks <= 0:
        raise dash.exceptions.PreventUpdate
    if num_components not in (2, 3):
        raise ValueError(f"num_components must be 2 or 3, got {num_components!r}")
    # A cleared numeric input arrives as None; keep the current figure.
    if color_by != "split" and top_n_classes is None:
        raise dash.exceptions.PreventUpdate

    embedding = load_embedding(algorithm, num_components)
    color_series, color_map = _group_labels_and_colors(color_by, top_n_classes)
    is_3d = num_components == 3

    fig = go.Figure()
    for label in color_series.unique():
        # Positional row numbers: df rows and embedding rows are aligned.
        class_indices = np.flatnonzero((color_series == label).to_numpy())
        fig.add_trace(
            _make_trace(
                points=embedding[class_indices],
                label=label,
                # Labels without an assigned color fall back to grey.
                color=color_map.get(label, 'grey'),
                customdata=df.iloc[class_indices][['pdb_id', 'classification']],
                is_3d=is_3d,
            )
        )

    if is_3d:
        fig.update_layout(
            scene=dict(
                xaxis=dict(showgrid=False, showticklabels=False, title=""),
                yaxis=dict(showgrid=False, showticklabels=False, title=""),
                zaxis=dict(showgrid=False, showticklabels=False, title=""),
            ),
        )
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
    else:
        # Scenes only exist on 3D figures, so the cartesian axes have to be
        # hidden through the 2D axis API.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
#### GET CLOSEST POINTS
def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Pull the PDB id and classification out of a scatter-plot click event.

    Only the first entry of ``clickData["points"]`` is inspected, and only its
    ``"customdata"`` list is used. A click event looks like::

        {
            "points": [
                {
                    "x": 11.330583,
                    "y": 15.741333,
                    "z": -5.3435574,
                    "curveNumber": 2,
                    "pointNumber": 982,
                    "bbox": {...},
                    "customdata": [
                        "1zfp",
                        "complex (signal transduction/peptide)"
                    ]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Click event of a graph whose traces carry
        ``customdata = [pdb_id, classification]``.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` of the clicked point.
    """
    clicked_point = clickData["points"][0]
    custom = clicked_point["customdata"]
    return custom[0], custom[1]
def find_closest_n_points(df: pd.DataFrame,
                          embedding: np.ndarray,
                          index: int = None,
                          pdb_id: str = None,
                          n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` points of an embedding closest to a query point.

    The query point is identified either by its row position or by its PDB
    id. It is part of the search set itself, so it is normally the first
    result (distance zero).

    Parameters
    ----------
    df : pd.DataFrame
        Table with ``"pdb_id"`` and ``"classification"`` columns whose rows
        are aligned, by position, with the rows of ``embedding``.
    embedding : np.ndarray
        2D array with one row of coordinates per point.
    index : int, optional
        Row position of the query point. Ignored when ``pdb_id`` is given.
    pdb_id : str, optional
        PDB id of the query point; the first matching row is used.
    n : int
        Number of closest points to return.

    Returns
    -------
    closest_ids : list
        PDB ids of the closest points, nearest first.
    closest_ids_classifications : list
        Classifications of the same points, in the same order.

    Raises
    ------
    ValueError
        If neither ``index`` nor ``pdb_id`` is given, or ``pdb_id`` does not
        occur in ``df``.
    """
    if pdb_id:
        # Resolve to a row *position* (not an index label) because the result
        # is used to index the embedding array.
        matches = np.flatnonzero((df["pdb_id"] == pdb_id).to_numpy())
        if matches.size == 0:
            raise ValueError(f"pdb_id {pdb_id!r} not found in df")
        index = int(matches[0])
    if index is None:
        raise ValueError("either index or pdb_id must be provided")

    # cdist needs a 2D query, hence the added leading axis.
    distances = cdist(embedding[index, np.newaxis], embedding)
    closest_indices = np.argsort(distances[0])[:n]
    nearest = df.iloc[closest_indices]
    return nearest["pdb_id"].tolist(), nearest["classification"].tolist()
@app.callback(
    Output("closest-points", "children"),
    [
        Input("embedding-graph", "clickData")
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
    ]
)
def update_closest_points_div(
        clickData: dict,
        algorithm: str,
        num_components: int) -> list:
    """List the points closest to the one clicked in the embedding graph.

    Parameters
    ----------
    clickData : dict
        Click event of the embedding graph, or ``None`` before any click.
    algorithm : str
        Current value of the algorithm dropdown.
    num_components : int
        Current value of the number-of-dimensions dropdown.

    Returns
    -------
    list
        One card per neighbouring point (PDB id and classification), or a
        single hint element when nothing has been clicked yet.

    Notes
    -----
    NOTE(review): the embedding is chosen from the dropdowns' *current*
    values, which may differ from what the graph shows if the dropdowns were
    changed without pressing Update -- confirm whether that is acceptable.
    """
    if clickData is None:
        # No id on the hint: "closest-points" is already the id of the
        # container these children are rendered into.
        return [html.Div("Click on a data point to see the closest points.")]

    pdb_id, _ = extract_info_from_clickData(clickData)
    # Only touch the disk once there is an actual click to answer.
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_ids_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(neighbour_id, className="card-title"),
                    html.P(neighbour_class, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for neighbour_id, neighbour_class in zip(
            closest_ids, closest_ids_classifications)
    ]
# Script entry point: serve the app on all network interfaces (0.0.0.0),
# port 7680, with Dash debug mode switched off.
# NOTE(review): newer Dash releases provide `app.run` as the replacement for
# `run_server` -- confirm against the Dash version this app is deployed with.
if __name__ == "__main__":
    app.run_server(debug=False, host='0.0.0.0', port=7680)