edugp's picture
Sync with data tooling repo, using edugp/kenlm models, updating viz to use quantiles for coloring and ad-hoc viz for the registry dataset
3c30fa3
import logging
import time
from typing import Callable, Optional, Tuple, Union
import pandas as pd
import streamlit as st
from bokeh.palettes import Turbo256
from bokeh.plotting import Figure
from embedding_lenses.embedding import embed_text
from embedding_lenses.utils import encode_labels
from embedding_lenses.visualization import draw_interactive_scatter_plot
from sentence_transformers import SentenceTransformer
from perplexity_lenses import REGISTRY_DATASET
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Selectable embedding model names (presumably sentence-transformers /
# Hugging Face Hub identifiers -- confirm against the caller that loads them).
EMBEDDING_MODELS = [
    "distiluse-base-multilingual-cased-v1",
    "distiluse-base-multilingual-cased-v2",
    "all-mpnet-base-v2",
    "flax-sentence-embeddings/all_datasets_v3_mpnet-base",
]

# Names of the supported dimensionality-reduction algorithms.
DIMENSIONALITY_REDUCTION_ALGORITHMS = ["UMAP", "t-SNE"]

# Granularity options for what counts as one document.
DOCUMENT_TYPES = ["Whole document", "Sentence"]

# Seed constant (presumably the default random seed used by callers).
SEED = 0

# Two-letter language codes, in alphabetical order.
LANGUAGES = [
    "af", "ar", "az", "be", "bg", "bn", "ca", "cs",
    "da", "de", "el", "en", "es", "et", "fa", "fi",
    "fr", "gu", "he", "hi", "hr", "hu", "hy", "id",
    "is", "it", "ja", "ka", "kk", "km", "kn", "ko",
    "lt", "lv", "mk", "ml", "mn", "mr", "my", "ne",
    "nl", "no", "pl", "pt", "ro", "ru", "uk", "zh",
]

# Names of the corpora the available perplexity models are labelled with
# (presumably the corpus each model was trained on -- confirm).
PERPLEXITY_MODELS = ["Wikipedia", "OSCAR"]
class ContextLogger:
    """Context manager that logs a message on entry and the elapsed time on exit.

    It is constructed with a single ``text`` argument and used in a ``with``
    statement. On entry ``text`` is logged at INFO level; on exit the number
    of seconds spent inside the ``with`` body is logged at INFO level.

    Attributes:
        text: Message logged when the context is entered.
        start_time: ``time.time()`` timestamp of the most recent entry (or of
            construction, if the context has not been entered yet).
    """

    def __init__(self, text: str = ""):
        self.text = text
        # Initialised here so the attribute always exists; it is refreshed in
        # __enter__ so the reported duration covers only the `with` body.
        self.start_time = time.time()

    def __enter__(self):
        logger.info(self.text)
        # Restart the timer on entry: an instance may be created well before
        # it is entered, or entered more than once.
        self.start_time = time.time()
        # Return the instance so `with ContextLogger(...) as ctx` is usable.
        return self

    def __exit__(self, type, value, traceback):
        # Returns None (falsy), so exceptions raised in the body propagate.
        logger.info(f"Took: {time.time() - self.start_time:.4f} seconds")
def generate_plot(
    df: pd.DataFrame,
    text_column: str,
    label_column: str,
    sample: Optional[int],
    dimensionality_reduction_function: Callable,
    model: SentenceTransformer,
    seed: int = 0,
    context_logger: Union[st.spinner, ContextLogger] = ContextLogger,
    hub_dataset: str = "",
) -> Tuple[Figure, Optional[Figure]]:
    """Embed a text column, reduce it to 2D and build interactive scatter plots.

    The caller's ``df`` is never modified: every transformation below works on
    a new DataFrame.

    Args:
        df: Input data. Must contain ``text_column``.
        text_column: Name of the column holding the texts to embed.
        label_column: Name of the column holding the per-row values. If the
            column is missing, a constant ``0`` column is used instead.
        sample: Maximum number of rows to keep (random sample). A falsy value
            (``None`` or ``0``) disables sampling.
        dimensionality_reduction_function: Callable applied to the embeddings;
            its result is indexed as ``result[:, 0]`` and ``result[:, 1]``.
        model: Model handed to ``embed_text`` together with the texts.
        seed: Random state used when sampling rows.
        context_logger: Callable accepting a ``text`` argument and returning a
            context manager (e.g. ``st.spinner`` or ``ContextLogger``); wraps
            the embedding and dimensionality-reduction steps.
        hub_dataset: Dataset identifier. When it equals ``REGISTRY_DATASET`` a
            second plot based on the ``"label"`` column is produced.

    Returns:
        A tuple ``(plot, plot_registry)``. ``plot_registry`` is ``None`` unless
        ``hub_dataset == REGISTRY_DATASET``.

    Raises:
        ValueError: If ``text_column`` is not a column of ``df``.
    """
    if text_column not in df.columns:
        raise ValueError(
            f"The specified column name doesn't exist. Columns available: {df.columns.values}"
        )
    if label_column not in df.columns:
        # `assign` returns a new DataFrame, so the placeholder label column is
        # not written into the caller's object.
        df = df.assign(**{label_column: 0})
    df = df.dropna(subset=[text_column, label_column])
    if sample:
        # Never request more rows than are available.
        df = df.sample(min(sample, df.shape[0]), random_state=seed)

    with context_logger(text="Embedding text..."):
        embeddings = embed_text(df[text_column].values.tolist(), model)
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    with context_logger(text="Reducing dimensionality..."):
        embeddings_2d = dimensionality_reduction_function(embeddings)

    logger.info("Generating figure")
    hover_data = {
        text_column: df[text_column].values,
        label_column: encoded_labels.values,
    }
    # Round perplexity values
    values = df[label_column].values.round().astype(int)
    plot = draw_interactive_scatter_plot(
        hover_data,
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        values,
    )

    # Special case for the registry dataset: build an additional plot driven
    # by the encoded "label" column, reusing the same 2D coordinates.
    plot_registry = None
    if hub_dataset == REGISTRY_DATASET:
        encoded_labels = encode_labels(df["label"])
        hover_data = {
            text_column: df[text_column].values,
            "label": df["label"].values,
            label_column: df[label_column].values,
        }
        plot_registry = draw_interactive_scatter_plot(
            hover_data,
            embeddings_2d[:, 0],
            embeddings_2d[:, 1],
            encoded_labels.values,
            palette=Turbo256,
        )
    return plot, plot_registry