Spaces:
Running
Running
"""Plotting utilities.""" | |
from typing import Tuple

import numpy as np
from bokeh.embed import components
from bokeh.layouts import column
from bokeh.models import Column, CustomJS, Slider
from bokeh.plotting import ColumnDataSource, Figure, figure
def barplot(attended: np.ndarray, weights: np.ndarray, k: int = 4) -> Column:
    """
    Bokeh barplot showing the top k attention weights.

    k is interactively changeable via a slider placed above the plot. The
    bars shown are always the first k entries of ``attended``/``weights``,
    so the inputs are presumably sorted by descending weight by the caller
    (not checked here).

    Args:
        attended (np.ndarray): Names of the attended entities.
        weights (np.ndarray): Attention weights, aligned with ``attended``.
        k (int): Initial number of bars to show. Clamped to
            ``[1, len(attended)]``. Defaults to 4.

    Returns:
        bokeh.models.Column: Layout holding the slider and the figure. Can be
        visualized for debugging via bokeh.plotting (i.e. output_file, show).

    Raises:
        ValueError: If ``attended`` is empty or its length differs from
            that of ``weights``.
    """
    n_items = len(attended)
    if n_items == 0:
        raise ValueError("barplot needs at least one attended entity.")
    if len(weights) != n_items:
        raise ValueError(
            "attended and weights must have the same length, "
            f"got {n_items} and {len(weights)}."
        )
    # Clamp the initial k so the slider value always lies inside its range.
    k_init = max(1, min(int(k), n_items))

    # Only the initially visible bars go into the source; the JS callback
    # re-slices the full arrays whenever the slider moves.
    source = ColumnDataSource(
        data=dict(attended=attended[:k_init], weights=weights[:k_init]),
    )
    # Bokeh rejects sliders whose start equals their end, so keep end >= 2;
    # slicing past the end of the data in the callback is harmless.
    top_k_slider = Slider(
        start=1, end=max(n_items, 2), value=k_init, step=1, title="k"
    )
    p = figure(
        x_range=list(attended[:k_init]),  # adapted by callback
        plot_height=350,
        title="Top k Gene Attention Weights",
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,save,reset",
    )
    p.vbar(x="attended", top="weights", source=source, width=0.9)
    p.xgrid.grid_line_color = None
    p.y_range.start = 0

    callback = CustomJS(
        args=dict(
            source=source,
            xrange=p.x_range,
            yrange=p.y_range,
            attended=attended,
            weights=weights,
            top_k=top_k_slider,
        ),
        code="""
    const k = top_k.value;
    const data = source.data;
    data['attended'] = attended.slice(0, k);
    data['weights'] = weights.slice(0, k);
    // Show only the selected categories on the x axis.
    xrange.factors = data['attended'];
    // Rescale y to the visible bars (matters when weights are not sorted
    // in descending order).
    yrange.end = Math.max(...data['weights']) * 1.05;
    // Notify BokehJS once, after data and ranges are consistent.
    source.change.emit();
    """,
    )
    top_k_slider.js_on_change("value", callback)
    return column(top_k_slider, p)
def embed_barplot(attended: np.ndarray, weights: np.ndarray) -> Tuple[str, str]:
    """Render the top-k attention barplot as embeddable HTML components.

    Builds the interactive slider + barplot layout with ``barplot`` and
    converts it with ``bokeh.embed.components``.

    Args:
        attended (np.ndarray): Names of the attended entities.
        weights (np.ndarray): Attention weights.

    Returns:
        Tuple[str, str]: ``(script, div)``, i.e. the JavaScript that builds
        the plot and the HTML element it renders into.
    """
    layout = barplot(attended, weights)
    script, div = components(layout)
    return script, div