File size: 4,631 Bytes
1f30dbc
 
 
 
 
3c30fa3
 
 
 
86e673e
1f30dbc
3c30fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
a86046b
3c30fa3
1f30dbc
 
 
 
 
 
 
 
7b62017
 
 
1f30dbc
 
 
 
 
 
 
 
a86046b
 
 
 
 
 
 
 
 
 
 
3c30fa3
 
 
a86046b
7b62017
 
 
1f30dbc
 
3c30fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f30dbc
 
 
7b62017
 
 
1f30dbc
 
 
3c30fa3
 
 
 
 
 
 
 
1f30dbc
 
 
 
 
a86046b
 
abf62cb
1f30dbc
7b62017
 
 
 
 
 
 
 
 
 
a86046b
 
 
7b62017
 
 
3c30fa3
7b62017
 
 
 
 
 
 
 
3c30fa3
7b62017
3c30fa3
1f30dbc
3c30fa3
 
 
 
1f30dbc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
from functools import partial

import streamlit as st
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import (
    get_tsne_embeddings,
    get_umap_embeddings,
)
from embedding_lenses.embedding import load_model

from perplexity_lenses import REGISTRY_DATASET
from perplexity_lenses.data import (
    documents_df_to_sentences_df,
    hub_dataset_to_dataframe,
)
from perplexity_lenses.engine import (
    DIMENSIONALITY_REDUCTION_ALGORITHMS,
    DOCUMENT_TYPES,
    EMBEDDING_MODELS,
    LANGUAGES,
    PERPLEXITY_MODELS,
    SEED,
    generate_plot,
)
from perplexity_lenses.perplexity import KenlmModel
from perplexity_lenses.visualization import draw_histogram

# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# --- Page header and first data source: a user-uploaded file ---------------
st.title("Perplexity Lenses")
st.write(
    "Visualize text embeddings in 2D using colors to represent perplexity values."
)
# The uploader is restricted to the "csv" and "tsv" extensions.
uploaded_file = st.file_uploader(
    label="Choose an csv/tsv file...",
    type=["csv", "tsv"],
)
st.write(
    "Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)"
)
# --- Second data source: a dataset from the hub ------------------------------
# Widgets are created directly on the column containers, which places them in
# the same columns as entering each column with a ``with`` block would.
name_col, config_col, split_col = st.columns(3)
hub_dataset = name_col.text_input("Dataset name", value="mc4")
hub_dataset_config = config_col.text_input("Dataset configuration", value="es")
hub_dataset_split = split_col.text_input("Dataset split", value="train")

text_field_col, language_col = st.columns(2)
text_column = text_field_col.text_input("Text field name", value="text")
# index=12 preselects the 13th entry of LANGUAGES.
language = language_col.selectbox("Language", LANGUAGES, index=12)

doc_type_col, sample_col = st.columns(2)
# index=1 preselects the second entry of DOCUMENT_TYPES.
doc_type = doc_type_col.selectbox("Document type", DOCUMENT_TYPES, index=1)
sample = sample_col.number_input(
    "Maximum number of documents to use",
    min_value=1,
    max_value=100000,
    value=1000,
)

# The selected perplexity-model name is lower-cased before it is used.
selected_perplexity_model = st.selectbox(
    "Dataset on which the perplexity model was trained on",
    PERPLEXITY_MODELS,
    index=0,
)
perplexity_model = selected_perplexity_model.lower()

dimensionality_reduction = st.selectbox(
    "Dimensionality Reduction algorithm",
    DIMENSIONALITY_REDUCTION_ALGORITHMS,
    index=0,
)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, index=0)

# KenLM preprocessing settings used whenever the advanced panel is closed.
lower_case = True
remove_accents = True
normalize_numbers = True
punctuation = 1

advanced_options = st.checkbox(
    "Advanced options (do not modify if using default KenLM models).", value=False
)
if advanced_options:
    # Every widget starts from the setting that is already in effect, so just
    # opening the panel leaves the preprocessing unchanged; only an explicit
    # user edit alters it. (Previously the first two checkboxes started
    # unchecked, silently switching lower-casing and accent removal off.)
    lower_case = st.checkbox(
        "Lower case text for KenLM preprocessing (from cc_net)", value=lower_case
    )
    remove_accents = st.checkbox(
        "Remove accents for KenLM preprocessing (from cc_net)", value=remove_accents
    )
    normalize_numbers = st.checkbox(
        "Replace numbers with zeros KenLM preprocessing (from cc_net)",
        value=normalize_numbers,
    )
    # The punctuation mode is limited to the integers 1 and 2.
    punctuation = st.number_input(
        "Punctuation mode to use from cc_net KenLM preprocessing",
        min_value=1,
        max_value=2,
        value=punctuation,
    )

# Load the sentence-embedding model chosen in the UI, showing a spinner
# while the call runs.
with st.spinner(text="Loading embedding model..."):
    model = load_model(model_name)

# "UMAP" selects UMAP; any other choice falls back to t-SNE. Both reducers
# are bound to the shared SEED.
if dimensionality_reduction == "UMAP":
    dimensionality_reduction_function = partial(
        get_umap_embeddings, random_state=SEED
    )
else:
    dimensionality_reduction_function = partial(
        get_tsne_embeddings, random_state=SEED
    )

# Preprocessing flags, in the positional order from_pretrained receives them.
kenlm_preprocessing = (lower_case, remove_accents, normalize_numbers, punctuation)
with st.spinner(text="Loading KenLM model..."):
    kenlm_model = KenlmModel.from_pretrained(
        perplexity_model, language, *kenlm_preprocessing
    )

# Main pipeline: load the data, score each text with the KenLM model, then
# render the 2D embedding plot(s) and a perplexity histogram.
# An uploaded file takes precedence over the hub dataset.
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
            if doc_type == "Sentence":
                df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
            elif len(df) > sample:
                # Honor "Maximum number of documents to use" for uploaded
                # documents too; previously the limit was ignored on this
                # path and every row was scored and embedded.
                df = df.sample(n=sample, random_state=SEED)
            df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
        else:
            df = hub_dataset_to_dataframe(
                hub_dataset,
                hub_dataset_config,
                hub_dataset_split,
                sample,
                text_column,
                kenlm_model,
                seed=SEED,
                doc_type=doc_type,
            )

    # Round perplexity to whole numbers.
    df["perplexity"] = df["perplexity"].round().astype(int)
    logger.info(
        "Perplexity range: %s - %s",
        df["perplexity"].min(),
        df["perplexity"].max(),
    )
    # The fourth positional argument is passed as None here.
    # NOTE(review): hub_dataset is forwarded even when an uploaded file is the
    # data source -- confirm generate_plot tolerates that combination.
    plot, plot_registry = generate_plot(
        df,
        text_column,
        "perplexity",
        None,
        dimensionality_reduction_function,
        model,
        seed=SEED,
        context_logger=st.spinner,
        hub_dataset=hub_dataset,
    )
    logger.info("Displaying plots")
    st.bokeh_chart(plot)
    # The second chart is only shown for the registry dataset.
    if hub_dataset == REGISTRY_DATASET:
        st.bokeh_chart(plot_registry)
    fig = draw_histogram(df["perplexity"].values)
    st.pyplot(fig)
    logger.info("Done")