edugp's picture
Initial commit for perplexity lenses
1f30dbc
raw
history blame
No virus
4.51 kB
import logging
from functools import partial
from typing import Callable, Optional
import pandas as pd
import streamlit as st
from bokeh.plotting import Figure
from embedding_lenses.data import uploaded_file_to_dataframe
from embedding_lenses.dimensionality_reduction import (get_tsne_embeddings,
get_umap_embeddings)
from embedding_lenses.embedding import embed_text, load_model
from embedding_lenses.utils import encode_labels
from embedding_lenses.visualization import draw_interactive_scatter_plot
from sentence_transformers import SentenceTransformer
from data import hub_dataset_to_dataframe
from perplexity import KenlmModel
# Configure the root logger so that records of level INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Sentence-embedding model names offered in the "Sentence embedding model" selector.
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Names offered in the "Dimensionality Reduction algorithm" selector.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Language codes offered in the "Language" selector, in alphabetical order.
# The selector's default index (12) points at "es", so keep positions stable.
LANGUAGES = [
    "af", "ar", "az", "be", "bg", "bn", "ca", "cs",
    "da", "de", "el", "en", "es", "et", "fa", "fi",
    "fr", "gu", "he", "hi", "hr", "hu", "hy", "id",
    "is", "it", "ja", "ka", "kk", "km", "kn", "ko",
    "lt", "lv", "mk", "ml", "mn", "mr", "my", "ne",
    "nl", "no", "pl", "pt", "ro", "ru", "uk", "zh",
]

# Random seed shared by row sampling and the dimensionality-reduction functions.
SEED = 0
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed the texts of ``df`` and draw them as an interactive 2D scatter plot.

    Args:
        df: Input data; must contain ``text_column``. The caller's frame is
            never modified.
        text_column: Name of the column holding the texts to embed.
        label_column: Name of the column used to label the points. When the
            column is absent, every row gets the constant label ``0``.
        sample: Maximum number of rows to plot. A falsy value (``None`` or
            ``0``) keeps every row.
        dimensionality_reduction_function: Callable applied to the embeddings;
            its result is indexed as ``[:, 0]`` and ``[:, 1]`` for the plot
            coordinates.
        model: Model handed to ``embed_text`` to compute the embeddings.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # assign() returns a new frame, so the caller's DataFrame is not
        # mutated by adding the placeholder label column.
        df = df.assign(**{label_column: 0})
    # Rows missing either the text or the label cannot be plotted.
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        # Cap the sample size at the number of available rows.
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
    return plot
# --- Page header -----------------------------------------------------------
st.title("Perplexity Lenses")
st.write("Visualize text embeddings in 2D using colors to represent perplexity values.")

# --- Data source: an uploaded file, or a dataset from the hub ---------------
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", value="mc4")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", value="es")
with col3:
    hub_dataset_split = st.text_input("Dataset split", value="train")

# --- Processing options -----------------------------------------------------
text_column = st.text_input("Text column name", value="text")
language = st.selectbox("Language", options=LANGUAGES, index=12)
sample = st.number_input("Maximum number of documents to use", min_value=1, max_value=100000, value=1000)
dimensionality_reduction = st.selectbox(
    "Dimensionality Reduction algorithm",
    options=DIMENSIONALITY_REDUCTION_ALGORITHMS,
    index=0,
)
model_name = st.selectbox("Sentence embedding model", options=EMBEDDING_MODELS, index=0)

# --- Models and reduction function ------------------------------------------
with st.spinner(text="Loading embedding model..."):
    model = load_model(model_name)

# Anything other than "UMAP" falls back to t-SNE; both are seeded with SEED.
if dimensionality_reduction == "UMAP":
    _reducer = get_umap_embeddings
else:
    _reducer = get_tsne_embeddings
dimensionality_reduction_function = partial(_reducer, random_state=SEED)

with st.spinner(text="Loading KenLM model..."):
    kenlm_model = KenlmModel.from_pretrained(language)
# Build the plot once a data source is available; an uploaded file takes
# precedence over the hub dataset.
if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
            # Perplexity comes from the KenLM model, not the sentence-embedding
            # model, and Series.map already passes each cell value (the text
            # itself) to the callable. Missing texts are skipped here and their
            # rows are dropped later by generate_plot.
            df["perplexity"] = df[text_column].map(kenlm_model.get_perplexity, na_action="ignore")
        else:
            df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample, text_column, kenlm_model, seed=SEED)
    # Points are labelled by the "perplexity" column.
    plot = generate_plot(df, text_column, "perplexity", sample, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")