edugp's picture
Sync with data tooling repo, using edugp/kenlm models, updating viz to use quantiles for coloring and ad-hoc viz for the registry dataset
3c30fa3
import logging
from functools import partial
import streamlit as st
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import (
get_tsne_embeddings,
get_umap_embeddings,
)
from embedding_lenses.embedding import load_model
from perplexity_lenses import REGISTRY_DATASET
from perplexity_lenses.data import (
documents_df_to_sentences_df,
hub_dataset_to_dataframe,
)
from perplexity_lenses.engine import (
DIMENSIONALITY_REDUCTION_ALGORITHMS,
DOCUMENT_TYPES,
EMBEDDING_MODELS,
LANGUAGES,
PERPLEXITY_MODELS,
SEED,
generate_plot,
)
from perplexity_lenses.perplexity import KenlmModel
from perplexity_lenses.visualization import draw_histogram
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write(
"Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)"
)
col1, col2, col3 = st.columns(3)
with col1:
hub_dataset = st.text_input("Dataset name", "mc4")
with col2:
hub_dataset_config = st.text_input("Dataset configuration", "es")
with col3:
hub_dataset_split = st.text_input("Dataset split", "train")
col4, col5 = st.columns(2)
with col4:
text_column = st.text_input("Text field name", "text")
with col5:
language = st.selectbox("Language", LANGUAGES, 12)
col6, col7 = st.columns(2)
with col6:
doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
with col7:
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
perplexity_model = st.selectbox(
"Dataset on which the perplexity model was trained on", PERPLEXITY_MODELS, 0
).lower()
dimensionality_reduction = st.selectbox(
"Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0
)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
advanced_options = st.checkbox(
"Advanced options (do not modify if using default KenLM models).", value=False
)
lower_case = True
remove_accents = True
normalize_numbers = True
punctuation = 1
if advanced_options:
lower_case = st.checkbox(
"Lower case text for KenLM preprocessing (from cc_net)", value=False
)
remove_accents = st.checkbox(
"Remove accents for KenLM preprocessing (from cc_net)", value=False
)
normalize_numbers = st.checkbox(
"Replace numbers with zeros KenLM preprocessing (from cc_net)", value=True
)
punctuation = st.number_input(
"Punctuation mode to use from cc_net KenLM preprocessing", 1, 2, 1
)
with st.spinner(text="Loading embedding model..."):
model = load_model(model_name)
dimensionality_reduction_function = (
partial(get_umap_embeddings, random_state=SEED)
if dimensionality_reduction == "UMAP"
else partial(get_tsne_embeddings, random_state=SEED)
)
with st.spinner(text="Loading KenLM model..."):
kenlm_model = KenlmModel.from_pretrained(
perplexity_model,
language,
lower_case,
remove_accents,
normalize_numbers,
punctuation,
)
if uploaded_file or hub_dataset:
with st.spinner("Loading dataset..."):
if uploaded_file:
df = uploaded_file_to_dataframe(uploaded_file)
if doc_type == "Sentence":
df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
else:
df = hub_dataset_to_dataframe(
hub_dataset,
hub_dataset_config,
hub_dataset_split,
sample,
text_column,
kenlm_model,
seed=SEED,
doc_type=doc_type,
)
# Round perplexity
df["perplexity"] = df["perplexity"].round().astype(int)
logger.info(
f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
)
plot, plot_registry = generate_plot(
df,
text_column,
"perplexity",
None,
dimensionality_reduction_function,
model,
seed=SEED,
context_logger=st.spinner,
hub_dataset=hub_dataset,
)
logger.info("Displaying plots")
st.bokeh_chart(plot)
if hub_dataset == REGISTRY_DATASET:
st.bokeh_chart(plot_registry)
fig = draw_histogram(df["perplexity"].values)
st.pyplot(fig)
logger.info("Done")