# Streamlit app "Perplexity Lenses".
#
# Flow of the script, top to bottom:
#   1. Declare the input widgets (file upload or hub dataset coordinates,
#      text column, language, document type, sample size, model choices).
#   2. Load the sentence-embedding model and the KenLM perplexity model.
#   3. Build a dataframe with a "perplexity" column from the chosen source.
#   4. Render the 2D embedding plot(s) and a perplexity histogram.
#
# The order of the `st.*` calls below defines the page layout, so it must
# not be rearranged.
import logging
from functools import partial

import streamlit as st
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import (
    get_tsne_embeddings,
    get_umap_embeddings,
)
from embedding_lenses.embedding import load_model

from perplexity_lenses import REGISTRY_DATASET
from perplexity_lenses.data import (
    documents_df_to_sentences_df,
    hub_dataset_to_dataframe,
)
from perplexity_lenses.engine import (
    DIMENSIONALITY_REDUCTION_ALGORITHMS,
    DOCUMENT_TYPES,
    EMBEDDING_MODELS,
    LANGUAGES,
    PERPLEXITY_MODELS,
    SEED,
    generate_plot,
)
from perplexity_lenses.perplexity import KenlmModel
from perplexity_lenses.visualization import draw_histogram

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Page header and data-source widgets -----------------------------------
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write(
    "Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)"
)

# Hub dataset coordinates: name / configuration / split, side by side.
name_col, config_col, split_col = st.columns(3)
hub_dataset = name_col.text_input("Dataset name", "mc4")
hub_dataset_config = config_col.text_input("Dataset configuration", "es")
hub_dataset_split = split_col.text_input("Dataset split", "train")

# Which column holds the text, and which language it is written in.
text_col, language_col = st.columns(2)
text_column = text_col.text_input("Text field name", "text")
language = language_col.selectbox("Language", LANGUAGES, index=12)

# Granularity of the analysed units and an upper bound on how many to use.
doc_type_col, sample_col = st.columns(2)
doc_type = doc_type_col.selectbox("Document type", DOCUMENT_TYPES, index=1)
sample = sample_col.number_input(
    "Maximum number of documents to use",
    min_value=1,
    max_value=100000,
    value=1000,
)

# --- Model selection widgets -----------------------------------------------
perplexity_dataset = st.selectbox(
    "Dataset on which the perplexity model was trained on",
    PERPLEXITY_MODELS,
    index=0,
)
# The KenLM loader is given the lower-cased form of the selected entry.
perplexity_model = perplexity_dataset.lower()
dimensionality_reduction = st.selectbox(
    "Dimensionality Reduction algorithm",
    DIMENSIONALITY_REDUCTION_ALGORITHMS,
    index=0,
)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, index=0)

# --- KenLM preprocessing options -------------------------------------------
advanced_options = st.checkbox(
    "Advanced options (do not modify if using default KenLM models).", value=False
)
if advanced_options:
    # NOTE(review): lower-casing and accent removal start unchecked here,
    # while the non-advanced path below uses True for both. Enabling the
    # advanced panel therefore changes preprocessing before any box is
    # touched -- confirm this is intended.
    lower_case = st.checkbox(
        "Lower case text for KenLM preprocessing (from cc_net)", value=False
    )
    remove_accents = st.checkbox(
        "Remove accents for KenLM preprocessing (from cc_net)", value=False
    )
    normalize_numbers = st.checkbox(
        "Replace numbers with zeros KenLM preprocessing (from cc_net)", value=True
    )
    punctuation = st.number_input(
        "Punctuation mode to use from cc_net KenLM preprocessing",
        min_value=1,
        max_value=2,
        value=1,
    )
else:
    # Settings used whenever the advanced panel is left closed.
    lower_case = True
    remove_accents = True
    normalize_numbers = True
    punctuation = 1

# --- Model loading ---------------------------------------------------------
with st.spinner(text="Loading embedding model..."):
    model = load_model(model_name)

# Any selection other than "UMAP" falls back to t-SNE; both reducers get the
# shared seed as their random state.
if dimensionality_reduction == "UMAP":
    reducer = get_umap_embeddings
else:
    reducer = get_tsne_embeddings
dimensionality_reduction_function = partial(reducer, random_state=SEED)

with st.spinner(text="Loading KenLM model..."):
    kenlm_model = KenlmModel.from_pretrained(
        perplexity_model,
        language,
        lower_case,
        remove_accents,
        normalize_numbers,
        punctuation,
    )

# --- Data loading, scoring and plotting ------------------------------------
# Nothing below runs unless a file was uploaded or a hub dataset name is set.
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            # An uploaded file takes precedence over the hub dataset fields.
            df = uploaded_file_to_dataframe(uploaded_file)
            if doc_type == "Sentence":
                df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
            # Score every row of the text column with the KenLM model.
            df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
        else:
            # The hub loader receives the KenLM model; presumably it fills in
            # the "perplexity" column itself, since the code below reads it.
            df = hub_dataset_to_dataframe(
                hub_dataset,
                hub_dataset_config,
                hub_dataset_split,
                sample,
                text_column,
                kenlm_model,
                seed=SEED,
                doc_type=doc_type,
            )

    # Round perplexity
    df["perplexity"] = df["perplexity"].round().astype(int)
    logger.info(
        f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}"
    )

    plot, plot_registry = generate_plot(
        df,
        text_column,
        "perplexity",
        None,
        dimensionality_reduction_function,
        model,
        seed=SEED,
        context_logger=st.spinner,
        hub_dataset=hub_dataset,
    )

    logger.info("Displaying plots")
    st.bokeh_chart(plot)
    # The second chart is only drawn when the hub dataset name matches the
    # registry dataset constant.
    if hub_dataset == REGISTRY_DATASET:
        st.bokeh_chart(plot_registry)

    st.pyplot(draw_histogram(df["perplexity"].values))
    logger.info("Done")