import dash
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
from typing import List, Tuple, Union
from scipy.spatial.distance import cdist
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

# NOTE: read_pickle can execute arbitrary code on load; this file must come
# from a trusted source.
df = pd.read_pickle('all_embeddings_with_splits.p')

# Marker colours used when points are coloured by dataset split. Any split
# value not listed here is drawn in grey instead of raising a KeyError.
_SPLIT_COLORS = {
    "gpcr": "red",
    "train": "blue",
    "val": "green",
    "test": "orange",
    "unknown": "grey",
}

# Hover text shared by the 2D and 3D traces. customdata columns are
# [pdb_id, classification]; "<extra></extra>" suppresses the secondary
# trace-name box Plotly would otherwise show next to the tooltip.
_HOVER_TEMPLATE = (
    "PDB ID: %{customdata[0]}<br>"
    "Classification: %{customdata[1]}<br>"
    "<extra></extra>"
)

app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = dbc.Container(
    [
        html.H1("Embedding Plots"),
        html.Hr(),
        html.Div(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.Label('Algorithm:'),
                                dcc.Dropdown(
                                    id="algorithm-dropdown",
                                    options=[
                                        {"label": "PCA", "value": "pca"},
                                        {"label": "UMAP", "value": "umap"},
                                        {"label": "tSNE", "value": "tsne"},
                                        {"label": "PaCMAP", "value": "pacmap"},
                                    ],
                                    value="pacmap",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"},
                                ),
                                html.Label('Number of dimensions:'),
                                dcc.Dropdown(
                                    id="num-components-dropdown",
                                    options=[
                                        {"label": "2", "value": 2},
                                        {"label": "3", "value": 3},
                                    ],
                                    value=3,
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"},
                                ),
                                html.Label('Color by:'),
                                dcc.Dropdown(
                                    id="color-by",
                                    options=[
                                        {
                                            "label": "Protein Classification",
                                            "value": "classification",
                                        },
                                        {
                                            "label": "Split (train/test/val/gpcr)",
                                            "value": "split",
                                        },
                                    ],
                                    value="classification",
                                    clearable=False,
                                    searchable=False,
                                    style={"margin-bottom": "10px"},
                                ),
                                html.Span(
                                    [
                                        "Keep the top ",
                                        dcc.Input(
                                            id="top-n-classes",
                                            type="number",
                                            value=10,
                                            min=1,
                                            max=len(df["classification"].unique()),
                                            step=1,
                                            style={"width": "50px"},
                                        ),
                                        " classes.",
                                    ],
                                    style={"margin-bottom": "20px"},
                                ),
                                html.Br(),
                                dbc.Button(
                                    "Update",
                                    id="update-button",
                                    color="primary",
                                    n_clicks=0,
                                    style={"width": "100%", "margin": "10px 0px"},
                                ),
                                dbc.Container(
                                    id="closest-points",
                                    style={"max-height": "65vh", "overflow-y": "auto"},
                                ),
                            ],
                            width={"size": 2, "order": 1},
                        ),
                        dbc.Col(
                            dcc.Graph(
                                id="embedding-graph",
                                style={"height": "100%", "width": "100%"},
                            ),
                            width={"size": 10, "order": 2},
                        ),
                    ],
                    style={"height": "95vh"},
                )
            ],
            # Was "100hv" (invalid CSS unit, silently ignored by the browser).
            style={"height": "100vh"},
        ),
        html.Hr(),
    ],
    fluid=True,
)


def load_embedding(algorithm: str, num_components: int) -> np.ndarray:
    """Load a precomputed embedding for an algorithm and dimensionality.

    Parameters
    ----------
    algorithm : str
        Dropdown value: "pca", "umap", "tsne" or "pacmap".
    num_components : int
        Number of embedding dimensions (2 or 3). Ignored for "pca", which is
        always read from the single file ``pca.npy``.

    Returns
    -------
    np.ndarray
        The array stored in ``pca.npy`` or ``{algorithm}{num_components}d.npy``.
        Its rows are assumed to align positionally with the rows of ``df``
        (the callers index both with the same positions) — TODO confirm.
    """
    if algorithm == "pca":
        return np.load("pca.npy")
    return np.load(f"{algorithm}{num_components}d.npy")


def get_top_n_classifications(df: pd.DataFrame, n: int) -> List[str]:
    """Return the ``n`` most frequent values of ``df["classification"]``, most frequent first."""
    return df["classification"].value_counts().nlargest(n).index.tolist()


def _color_assignments(color_by: str, top_n_classes: int) -> Tuple[pd.Series, dict]:
    """Return a per-row legend label Series (aligned with ``df``) and a label -> colour map.

    Works on copies so the shared module-level ``df`` is never mutated.
    """
    if color_by == "split":
        labels = df["splits"].copy()
        color_map = {label: _SPLIT_COLORS.get(label, "grey") for label in labels.unique()}
        return labels, color_map

    top_classes = get_top_n_classifications(df, n=top_n_classes)
    # Every class outside the top-n is collapsed into a single "other" group.
    labels = df["classification"].astype(object).where(
        df["classification"].isin(top_classes), "other"
    )
    palette = px.colors.qualitative.Plotly
    # Cycle the palette: slicing it to top_n_classes raised IndexError for n > len(palette).
    top_colors = {cls: palette[i % len(palette)] for i, cls in enumerate(top_classes)}
    color_map = {label: top_colors.get(label, "grey") for label in labels.unique()}
    return labels, color_map


def _make_trace(points: np.ndarray, name, color: str, customdata: pd.DataFrame,
                num_components: int):
    """Build one marker trace: ``go.Scatter3d`` for 3 components, ``go.Scatter`` otherwise."""
    common = dict(
        mode='markers',
        name=name,
        marker=dict(
            size=2.5,
            color=color,
            # De-emphasise the grey "background" points.
            opacity=1 if color != 'grey' else 0.3,
        ),
        hovertemplate=_HOVER_TEMPLATE,
        customdata=customdata,
    )
    if num_components == 3:
        return go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2], **common)
    return go.Scatter(x=points[:, 0], y=points[:, 1], **common)


@app.callback(
    Output("embedding-graph", "figure"),
    [
        Input("update-button", "n_clicks"),
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
        State("top-n-classes", "value"),
        State("color-by", "value"),
    ]
)
def update_embedding_graph(n_clicks: int, algorithm: str, num_components: int,
                           top_n_classes: int, color_by: str) -> go.Figure:
    """Redraw the embedding scatter plot when the Update button is clicked.

    One trace is added per legend group: a dataset split, a top-n
    classification, or "other".

    Raises
    ------
    dash.exceptions.PreventUpdate
        Before the first click, or when colouring by classification while the
        "top n" input is empty/invalid (``dcc.Input`` then reports ``None``).
    """
    if not n_clicks:
        raise dash.exceptions.PreventUpdate
    if color_by != "split" and not top_n_classes:
        raise dash.exceptions.PreventUpdate

    embedding = load_embedding(algorithm, num_components)
    labels, color_map = _color_assignments(color_by, top_n_classes)
    label_values = labels.to_numpy()

    fig = go.Figure()
    for label in labels.unique():
        positions = np.flatnonzero(label_values == label)
        fig.add_trace(
            _make_trace(
                embedding[positions],
                label,
                color_map.get(label, 'grey'),
                df.iloc[positions][['pdb_id', 'classification']],
                num_components,
            )
        )

    if num_components == 3:
        fig.update_layout(
            scene=dict(
                xaxis=dict(showgrid=False, showticklabels=False, title=""),
                yaxis=dict(showgrid=False, showticklabels=False, title=""),
                zaxis=dict(showgrid=False, showticklabels=False, title=""),
            ),
        )
        fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
    else:
        fig.update_traces(marker=dict(size=7.5), selector=dict(mode='markers'))
        # 2D figures have no "scene"; update_scenes was a no-op here, so hide
        # the cartesian axes explicitly.
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)

    fig.update_layout(
        legend=dict(
            x=0,
            y=1,
            itemsizing='constant',
            itemclick='toggle',
            itemdoubleclick='toggleothers',
            traceorder='reversed',
            itemwidth=30,
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig


#### GET CLOSEST POINTS

def extract_info_from_clickData(clickData: dict) -> Tuple[str, str]:
    """Extract the pdb_id and classification of the first clicked point.

    ``clickData`` comes from clicking a point in the scatter plot, e.g.::

        {
            "points": [
                {
                    "x": 11.330583, "y": 15.741333, "z": -5.3435574,
                    "curveNumber": 2, "pointNumber": 982,
                    "bbox": {...},
                    "customdata": ["1zfp", "complex (signal transduction/peptide)"]
                }
            ]
        }

    Parameters
    ----------
    clickData : dict
        Click payload of a ``dcc.Graph``.

    Returns
    -------
    Tuple[str, str]
        ``(pdb_id, classification)`` read from ``customdata[0]`` and ``customdata[1]``.
    """
    customdata = clickData["points"][0]["customdata"]
    return customdata[0], customdata[1]


def _position_of_pdb_id(df: pd.DataFrame, pdb_id: str) -> int:
    """Return the positional row number of the first row whose pdb_id equals ``pdb_id``."""
    matches = np.flatnonzero(df["pdb_id"].to_numpy() == pdb_id)
    if matches.size == 0:
        raise ValueError(f"pdb_id {pdb_id!r} not found in the DataFrame.")
    return int(matches[0])


def find_closest_n_points(df: pd.DataFrame, embedding: np.ndarray, index: int = None,
                          pdb_id: str = None, n: int = 20) -> Tuple[list, list]:
    """Find the ``n`` points closest (Euclidean distance) to a given point.

    The query point itself is included (distance 0). Rows of ``embedding`` are
    matched to rows of ``df`` by position.

    Parameters
    ----------
    df : pd.DataFrame
        Table with "pdb_id" and "classification" columns.
    embedding : np.ndarray
        2D array of embedding coordinates, one row per row of ``df``.
    index : int, optional
        Positional row number of the query point.
    pdb_id : str, optional
        If given, takes precedence over ``index``; the first row with this
        pdb_id is used as the query point.
    n : int
        Number of closest points to return.

    Returns
    -------
    Tuple[list, list]
        The pdb_ids and classifications of the closest points, nearest first.

    Raises
    ------
    ValueError
        If ``pdb_id`` is not found, or neither ``index`` nor ``pdb_id`` is given.
    """
    if pdb_id:
        # Positional lookup: df.index labels only match positions for a RangeIndex.
        index = _position_of_pdb_id(df, pdb_id)
    if index is None:
        raise ValueError("Either `index` or `pdb_id` must be provided.")
    distances = cdist(embedding[index, np.newaxis], embedding)
    closest_indices = np.argsort(distances)[0][:n]
    closest = df.iloc[closest_indices]
    return closest["pdb_id"].tolist(), closest["classification"].tolist()


@app.callback(
    Output("closest-points", "children"),
    [
        Input("embedding-graph", "clickData")
    ],
    [
        State("algorithm-dropdown", "value"),
        State("num-components-dropdown", "value"),
    ]
)
def update_closest_points_div(
        clickData: dict, algorithm: str, num_components: int
) -> Union[List[dbc.Card], html.Div]:
    """List the points nearest to the clicked one as a column of cards.

    NOTE(review): algorithm/num_components are read from the dropdowns at click
    time, so if they were changed without pressing Update, neighbours are
    computed in a different embedding than the one on screen.
    """
    if clickData is None:
        # No id here: reusing "closest-points" would duplicate the parent container's id.
        return html.Div(children=[html.Div("Click on a data point to see the closest points.")])

    pdb_id, _ = extract_info_from_clickData(clickData)
    embedding = load_embedding(algorithm, num_components)
    closest_ids, closest_classifications = find_closest_n_points(
        df, embedding, pdb_id=pdb_id)

    return [
        dbc.Card(
            dbc.CardBody(
                [
                    html.P(closest_id, className="card-title"),
                    html.P(classification, className="card-text"),
                ]
            ),
            className="mb-3",
        )
        for closest_id, classification in zip(closest_ids, closest_classifications)
    ]


if __name__ == "__main__":
    # Prefer `app.run`; `run_server` is deprecated in newer Dash releases.
    run = getattr(app, "run", None) or app.run_server
    run(debug=False, host='0.0.0.0', port=7680)