"""Streamlit app that visualizes sentence embeddings of a text column in 2D.

Texts come from an uploaded csv/tsv file or a Hugging Face hub dataset. They are
embedded with a SentenceTransformer, reduced to 2D with UMAP or t-SNE, and drawn as
an interactive Bokeh scatter plot coloured by a numerical/categorical column.
"""
import logging
from functools import partial
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import streamlit as st
import umap
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Palette
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]
SEED = 0


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Load the SentenceTransformer called *model_name* (memoized by ``st.cache``)."""
    return SentenceTransformer(model_name)


def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in *text* with *model* and return the embeddings."""
    return model.encode(text)


def encode_labels(labels: pd.Series) -> pd.Series:
    """Return *labels* unchanged if numeric, otherwise their integer category codes."""
    if pd.api.types.is_numeric_dtype(labels):
        return labels
    return labels.astype("category").cat.codes


def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 30,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project *embeddings* to *n_components* dimensions with scikit-learn's t-SNE.

    *perplexity* is clamped below the number of samples, because scikit-learn
    rejects ``perplexity >= n_samples`` (small samples would otherwise crash).
    """
    perplexity = min(perplexity, max(len(embeddings) - 1, 1))
    tsne = TSNE(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)


def get_umap_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """Project *embeddings* to 2 dimensions with UMAP (fixed seed for reproducibility)."""
    umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=SEED)
    return umap_model.fit_transform(embeddings)


def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str,
) -> Figure:
    """Build an 800x800 Bokeh scatter plot of (xs, ys), coloured by *values*.

    Args:
        texts: Text shown in the hover tooltip for each point.
        xs, ys: 2D coordinates of each point.
        values: Numeric value per point that drives the colour.
        labels: Original label per point, shown in the hover tooltip.
        text_column, label_column: Names used as tooltip captions.

    Returns:
        The configured Bokeh figure.
    """
    # Coerce to a plain float ndarray. This accepts bool and pandas nullable
    # arrays (numpy cannot subtract booleans). It also gives one consistent
    # string form for the colour-mapper factors.
    values = np.asarray(values, dtype=float)
    min_value = values.min()
    max_value = values.max()
    value_range = max_value - min_value

    # One factor per distinct value, each paired with its own colour. Building
    # both from the same numerically sorted array keeps them aligned;
    # independently sorting value strings and colour strings does not.
    unique_values = np.unique(values)
    if value_range == 0:
        color_indices = np.ones(len(unique_values), dtype=int)
    else:
        # Normalize to 0-255 so each value picks a colour from the 256-colour palette.
        color_indices = np.round((unique_values - min_value) / value_range * 255).astype(int)
    factors = unique_values.astype(str).tolist()
    palette = [Palette[int(index)] for index in color_indices]

    source = ColumnDataSource(
        data=dict(
            x=xs,
            y=ys,
            text=texts,
            label=values.astype(str).tolist(),
            original_label=labels.astype(str).tolist(),
        )
    )
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle(
        "x",
        "y",
        size=10,
        source=source,
        fill_color=factor_cmap("label", palette=palette, factors=factors),
    )

    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p


def uploaded_file_to_dataframe(uploaded_file: st.uploaded_file_manager.UploadedFile) -> pd.DataFrame:
    """Read an uploaded csv/tsv file; files ending in ``.tsv`` (any case) are tab-separated."""
    extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
    return pd.read_csv(uploaded_file, sep="\t" if extension == "tsv" else ",")


def hub_dataset_to_dataframe(path: str, name: str, split: str, sample: int) -> pd.DataFrame:
    """Load a hub dataset split, shuffle it with ``SEED`` and return up to *sample* rows.

    *name* (configuration) and *split* are only passed to ``load_dataset`` when non-empty.

    Raises:
        ValueError: if no split was given and ``load_dataset`` returned a
            dict-like collection of splits, which cannot be sliced into rows.
    """
    load_dataset_fn = partial(load_dataset, path=path)
    if name:
        load_dataset_fn = partial(load_dataset_fn, name=name)
    if split:
        load_dataset_fn = partial(load_dataset_fn, split=split)
    dataset = load_dataset_fn()
    if isinstance(dataset, dict):
        raise ValueError(f"Please specify a dataset split. Splits available: {list(dataset)}")
    # Slicing a Dataset yields a dict of column -> list, which DataFrame accepts.
    return pd.DataFrame(dataset.shuffle(seed=SEED)[:sample])


def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
) -> Figure:
    """Embed *text_column* of *df*, reduce it to 2D and return the scatter plot.

    If *label_column* is missing, every row gets the constant label 0. Rows with
    missing text or label are dropped. If *sample* is set, at most that many rows
    are drawn (seeded).

    Raises:
        ValueError: if *text_column* is not in *df*, or no rows remain after dropping missing values.
    """
    if text_column not in df.columns:
        raise ValueError(f"The specified column name doesn't exist. Columns available: {df.columns.values}")
    if label_column not in df.columns:
        # assign() returns a copy, so the caller's DataFrame is not mutated.
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if df.empty:
        raise ValueError(
            f"No rows with non-missing values in columns {text_column!r} and {label_column!r}."
        )
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)

    with st.spinner(text="Embedding text..."):
        # The encoder expects strings; numeric text columns would otherwise fail.
        embeddings = embed_text(df[text_column].astype(str).tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])

    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)

    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        encoded_labels.values,
        df[label_column].values,
        text_column,
        label_column,
    )
    return plot


st.title("Embedding Lenses")
st.write("Visualize text embeddings in 2D using colors for continuous or categorical labels.")
uploaded_file = st.file_uploader("Choose a csv/tsv file...", type=["csv", "tsv"])
st.write("Alternatively, select a dataset from the [hub](https://huggingface.co/datasets)")
col1, col2, col3 = st.columns(3)
with col1:
    hub_dataset = st.text_input("Dataset name", "ag_news")
with col2:
    hub_dataset_config = st.text_input("Dataset configuration", "")
with col3:
    hub_dataset_split = st.text_input("Dataset split", "train")

text_column = st.text_input("Text column name", "text")
label_column = st.text_input("Numerical/categorical column name (ignore if not applicable)", "label")
sample = st.number_input("Maximum number of documents to use", 1, 100000, 1000)
dimensionality_reduction = st.selectbox("Dimensionality Reduction algorithm", DIMENSIONALITY_REDUCTION_ALGORITHMS, 0)
model_name = st.selectbox("Sentence embedding model", EMBEDDING_MODELS, 0)

with st.spinner(text="Loading model..."):
    model = load_model(model_name)
dimensionality_reduction_function = (
    get_umap_embeddings if dimensionality_reduction == "UMAP" else get_tsne_embeddings
)

if uploaded_file or hub_dataset:
    with st.spinner("Loading dataset..."):
        if uploaded_file:
            df = uploaded_file_to_dataframe(uploaded_file)
        else:
            df = hub_dataset_to_dataframe(hub_dataset, hub_dataset_config, hub_dataset_split, sample)
        plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")