paccmann / plots.py
jannisborn's picture
update
ec53722 unverified
raw
history blame
2.53 kB
"""Plotting utilities."""
import numpy as np
from typing import Tuple
from bokeh.layouts import column
from bokeh.models import CustomJS, Slider
from bokeh.plotting import figure, Figure, ColumnDataSource
from bokeh.embed import components
def barplot(attended: np.ndarray, weights: np.ndarray) -> Figure:
    """
    Bokeh barplot showing top k attention weights.

    k is interactively changeable via a slider. "Top k" means the first k
    entries of the inputs as given; no sorting is performed here.

    Args:
        attended (np.ndarray): Names of the attended entities
        weights (np.ndarray): Attention weights

    Returns:
        The layout produced by ``bokeh.layouts.column`` holding the slider
        above the bar figure (not a bare figure, despite the annotation).
        Can be visualized for debugging via bokeh.plotting
        (i.e. output_file, show).
    """
    # Number of bars shown initially. Clamped to the number of entities so
    # the slider never starts above its own `end` when fewer than 4 are given.
    # NOTE(review): with fewer than two entities the slider range
    # (start=1, end=len(attended)) is degenerate; that case is not handled.
    K = min(4, len(attended))

    # reset from slider callback
    source = ColumnDataSource(
        data=dict(attended=attended, weights=weights),
    )
    top_k_slider = Slider(
        start=1, end=len(attended), value=K, step=1, title="k"
    )

    p = figure(
        x_range=source.data["attended"][:K],  # adapted by callback
        plot_height=350,
        title="Top k Gene Attention Weights",
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,save,reset",
    )
    p.vbar(x="attended", top="weights", source=source, width=0.9)

    # define the callback: on every slider move, truncate the plotted data to
    # the first k entries and rescale both axes to match.
    callback = CustomJS(
        args=dict(
            source=source,
            xrange=p.x_range,
            yrange=p.y_range,
            # full, untruncated arrays so the callback can grow k again
            attended=attended,
            weights=weights,
            top_k=top_k_slider,
        ),
        code="""
        var data = source.data;
        const k = top_k.value;
        data['attended'] = attended.slice(0, k)
        data['weights'] = weights.slice(0, k)
        source.change.emit();
        // not need if data is in descending order
        var yrange_arr = data['weights'];
        var yrange_max = Math.max(...yrange_arr) * 1.05;
        yrange.end = yrange_max;
        xrange.factors = data['attended'];
        source.change.emit();
        """,
    )
    top_k_slider.js_on_change("value", callback)

    layout = column(top_k_slider, p)
    p.xgrid.grid_line_color = None
    p.y_range.start = 0  # bars always grow from zero
    return layout
def embed_barplot(attended: np.ndarray, weights: np.ndarray) -> Tuple[str, str]:
    """Build the top-k attention barplot and return it in embeddable form.

    The plot is created by ``barplot`` (k is interactively changeable via a
    slider) and then passed to ``bokeh.embed.components``.

    Args:
        attended (np.ndarray): Names of the attended entities
        weights (np.ndarray): Attention weights

    Returns:
        Tuple[str, str]: javascript and html
    """
    plot_layout = barplot(attended, weights)
    embeddable = components(plot_layout)
    return embeddable