File size: 2,534 Bytes
ec53722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Plotting utilities."""
from typing import Tuple

import numpy as np
from bokeh.embed import components
from bokeh.layouts import column
from bokeh.models import Column, CustomJS, Slider
from bokeh.plotting import ColumnDataSource, Figure, figure


def barplot(attended: np.ndarray, weights: np.ndarray, k: int = 4) -> Column:
    """
    Bokeh barplot showing top k attention weights.

    k is interactively changeable via a slider. The first k entries of the
    inputs are shown, so the inputs are presumably pre-sorted by descending
    weight (TODO confirm with callers).

    Args:
        attended (np.ndarray): Names of the attended entities
        weights (np.ndarray): Attention weights, aligned with ``attended``
        k (int): Initial number of bars shown; clamped to ``len(attended)``.
            Defaults to 4.

    Returns:
        bokeh.models.Column: Slider stacked above the bar figure. Can be
            visualized for debugging via bokeh.plotting
            (i.e. output_file, show)

    Raises:
        ValueError: If ``attended`` is empty, if ``attended`` and ``weights``
            differ in length, or if ``k`` is smaller than 1.
    """
    n_entities = len(attended)
    if n_entities == 0:
        raise ValueError("barplot requires at least one attended entity")
    if len(weights) != n_entities:
        raise ValueError(
            "attended and weights must have the same length, "
            f"got {n_entities} and {len(weights)}"
        )
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    # The slider must not start beyond its own end when there are fewer
    # entities than the requested k.
    k = min(k, n_entities)

    # Start with exactly the top k rows so the initial glyphs match the
    # initial x_range factors; the slider callback replaces these columns.
    source = ColumnDataSource(
        data=dict(attended=attended[:k], weights=weights[:k]),
    )
    top_k_slider = Slider(start=1, end=n_entities, value=k, step=1, title="k")
    p = figure(
        x_range=list(attended[:k]),  # factors are replaced by the callback
        plot_height=350,
        title="Top k Gene Attention Weights",
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,save,reset",
    )
    p.vbar(x="attended", top="weights", source=source, width=0.9)
    # The full arrays are passed to the callback so it can re-slice them
    # client-side for any slider value.
    callback = CustomJS(
        args=dict(
            source=source,
            xrange=p.x_range,
            yrange=p.y_range,
            attended=attended,
            weights=weights,
            top_k=top_k_slider,
        ),
        code="""
        const data = source.data;
        const k = top_k.value;

        data['attended'] = attended.slice(0, k);
        data['weights'] = weights.slice(0, k);

        // Not needed if weights are in descending order (the maximum is
        // then weights[0]); kept so unsorted input still fits the axis.
        yrange.end = Math.max(...data['weights']) * 1.05;

        xrange.factors = data['attended'];

        // One notification after all in-place edits to source.data.
        source.change.emit();
    """,
    )
    top_k_slider.js_on_change("value", callback)
    p.xgrid.grid_line_color = None
    p.y_range.start = 0
    return column(top_k_slider, p)


def embed_barplot(attended: np.ndarray, weights: np.ndarray) -> Tuple[str, str]:
    """Build the top-k attention barplot and return its embeddable parts.

    The number of bars k is adjustable interactively through a slider.

    Args:
        attended (np.ndarray): Names of the attended entities
        weights (np.ndarray): Attention weights

    Returns:
        Tuple[str, str]: javascript and html, in the order produced by
            ``bokeh.embed.components``
    """
    layout = barplot(attended, weights)
    script, div = components(layout)
    return script, div