edugp's picture
Add CLI and refactor
86e673e
raw history blame
No virus
3.06 kB
import logging
from functools import partial
import streamlit as st
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import get_tsne_embeddings, get_umap_embeddings
from embedding_lenses.embedding import load_model
from perplexity_lenses.data import documents_df_to_sentences_df, hub_dataset_to_dataframe
from perplexity_lenses.engine import DIMENSIONALITY_REDUCTION_ALGORITHMS, DOCUMENT_TYPES, EMBEDDING_MODELS, LANGUAGES, SEED, generate_plot
from perplexity_lenses.perplexity import KenlmModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
with col1:
hub_dataset = st.text_input("Dataset name", "mc4")
with col2:
hub_dataset_config = st.text_input("Dataset configuration", "es")
with col3:
hub_dataset_split = st.text_input("Dataset split", "train")
col4, col5 = st.columns(2)
with col4:
text_column = st.text_input("Text field name", "text")
with col5:
language = st.selectbox("Language", LANGUAGES, 12)
col6, col7 = st.columns(2)
with col6:
doc_type = st.selectbox("Document type", DOCUMENT_TYPES, 1)
with col7:
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)
with st.spinner(text="Loading embedding model..."):
model = load_model(model_name)
dimensionality_reduction_function = (
partial(get_umap_embeddings, random_state=SEED) if dimensionality_reduction == "UMAP" else partial(get_tsne_embeddings, random_state=SEED)
)
with st.spinner(text="Loading KenLM model..."):
kenlm_model = KenlmModel.from_pretrained(language)
if uploaded_file or hub_dataset:
with st.spinner("Loading dataset..."):
if uploaded_file:
df = uploaded_file_to_dataframe(uploaded_file)
if doc_type == "Sentence":
df = documents_df_to_sentences_df(df, text_column, sample, seed=SEED)
df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity)
else:
df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED, doc_type=doc_type)
# Round perplexity
df["perplexity"] = df["perplexity"].round().astype(int)
logger.info(f"Perplexity range: {df['perplexity'].min()} - {df['perplexity'].max()}")
plot = generate_plot(df, text_column, "perplexity", None, dimensionality_reduction_function, model, seed=SEED, context_logger=st.spinner)
logger.info("Displaying plot")
st.bokeh_chart(plot)
logger.info("Done")