File size: 4,224 Bytes
fcd4a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ff603
fcd4a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cfc3bd
fcd4a61
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
## LIBRARIES ###
## Data
import pandas as pd
pd.options.display.float_format = '${:,.2f}'.format

# Analysis

# App & Visualization
import streamlit as st
from bokeh.models import CustomJS, ColumnDataSource, TextInput, DataTable, TableColumn
from bokeh.plotting import figure
from bokeh.transform import factor_cmap
from bokeh.palettes import Category20c_20
from bokeh.layouts import column, row

# utils

def datasets_explorer_viz(df):
    """Render an interactive scatter plot of embedded datapoints in Streamlit.

    The layout is a search box stacked above a Bokeh scatter plot, with a
    data table to the right:

    * Box-selecting points on the plot lists their ``dataset_id`` and
      ``text`` in the table.
    * Entering a term in the search box overlays, in red, every point whose
      ``dataset_id`` or ``text`` contains that term (case-sensitive).
      Clearing the search box removes the overlay.

    Args:
        df: DataFrame providing the columns ``x`` and ``y`` (plot
            coordinates), ``dataset_id`` (colour / legend grouping, tooltip,
            table) and ``text`` (tooltip, table, search).

    Returns:
        None. The layout is emitted via ``st.bokeh_chart``.
    """
    source = ColumnDataSource(df)
    tooltips = [("dataset_id", "@dataset_id"), ("text", "@text")]
    # NOTE(review): Category20c_20 has 20 colours; confirm how points from a
    # 21st+ dataset_id are rendered if the data ever grows past that.
    color = factor_cmap('dataset_id', palette=Category20c_20, factors=df['dataset_id'].unique())

    # ``width``/``height`` replace ``plot_width``/``plot_height``, which were
    # deprecated in Bokeh 2.4 and removed in Bokeh 3.0.
    p = figure(
        width=1000,
        height=800,
        tools="hover,wheel_zoom,pan,box_select",
        tooltips=tooltips,
        toolbar_location="above",
    )
    p.scatter(
        'x', 'y',
        size=5,
        source=source,
        alpha=0.8,
        marker='circle',
        fill_color=color,
        line_color=color,
        legend_field='dataset_id',
    )
    p.legend.location = "bottom_right"
    # NOTE(review): all legend items come from the single renderer above
    # (via legend_field), so muting one entry likely mutes every group.
    p.legend.click_policy = "mute"
    p.legend.label_text_font_size = "8pt"

    # Declare every column up front so the table and the highlight renderer
    # reference existing (empty) columns before any callback has fired.
    table_source = ColumnDataSource(data=dict(dataset_id=[], text=[]))
    selection_source = ColumnDataSource(data=dict(x=[], y=[], dataset_id=[], text=[]))

    columns = [
        TableColumn(field="dataset_id", title="Dataset ID"),
        TableColumn(field="text", title="Text"),
    ]
    data_table = DataTable(source=table_source, columns=columns, width=600)

    # Red overlay for search matches, drawn on top of the main scatter.
    p.scatter('x', 'y', source=selection_source, size=5, color='red', marker='circle')

    # Box selection -> table: copy the selected rows into table_source.
    source.selected.js_on_change('indices', CustomJS(args=dict(umap_source=source, table_source=table_source), code="""
        const umapData = umap_source.data;
        const text = [];
        const datasetId = [];
        for (const i of cb_obj.indices) {
            text.push(umapData['text'][i]);
            datasetId.push(umapData['dataset_id'][i]);
        }
        // Assign a fresh object so the table re-renders with the new rows.
        table_source.data = {text: text, dataset_id: datasetId};
    """))

    text_input = TextInput(value="", title="Search")

    # Search -> overlay: collect every point whose dataset_id or text
    # contains the query into selection_source.
    text_input.js_on_change('value', CustomJS(args=dict(plot_source=source, selection_source=selection_source), code="""
        const plotData = plot_source.data;
        const query = cb_obj.value;
        const x = [];
        const y = [];
        const datasetId = [];
        const text = [];

        // ''.includes('') is true, so an empty query would match every
        // point; treat it as "clear the highlight" instead.
        if (query !== '') {
            // Non-string cells (e.g. missing text) never match instead of
            // throwing and aborting the callback.
            const asString = (v) => (typeof v === 'string' ? v : '');
            for (let i = 0; i < plotData['dataset_id'].length; i++) {
                if (asString(plotData['dataset_id'][i]).includes(query) ||
                    asString(plotData['text'][i]).includes(query)) {
                    x.push(plotData['x'][i]);
                    y.push(plotData['y'][i]);
                    datasetId.push(plotData['dataset_id'][i]);
                    text.push(plotData['text'][i]);
                }
            }
        }
        selection_source.data = {x: x, y: y, dataset_id: datasetId, text: text};
    """))

    st.bokeh_chart(row(column(text_input, p), data_table))


if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    # set_page_config must be the first Streamlit command in the script.
    st.set_page_config(layout="wide", page_title="Datapoints Explorer")
    st.title('Interactive Datapoints Explorer for Text Classification')

    ### USAGE NOTES ###
    with st.expander("How to interact with the plot:"):
        st.markdown("* Each point in the plot represents an example from the HF hub text classification datasets.")
        st.markdown("* The datapoints are embedded using sentence embeddings of their `text` field.")
        st.markdown("* You can either search for a datapoint or drag and select to peek into the cluster content.")
        # Only drag (box) selections populate the table; search matches are
        # only highlighted on the plot.
        st.markdown("* If the term you are searching for matches `dataset_id` or `text` it will be highlighted in *red*. Points you drag-select are summarized in a table on the right.")

    ### LOAD DATA ###
    # The path is relative to the working directory the app is launched from.
    datasets_df = pd.read_parquet('./assets/data/datapoints_embeddings_df.parquet')
    st.warning("Hugging Face 🤗  Datapoints Explorer for Text Classification")
    datasets_explorer_viz(datasets_df)