|
import argparse |
|
import logging |
|
from typing import Any, Optional |
|
|
|
import bokeh |
|
import numpy as np |
|
import pandas as pd |
|
from bokeh.models import ColumnDataSource, HoverTool |
|
from bokeh.plotting import figure, output_file, save |
|
from bokeh.transform import factor_cmap |
|
from bokeh.palettes import Cividis256 as Pallete |
|
from sklearn.manifold import TSNE |
|
|
|
|
|
# Fixed seed shared by dataset sampling and t-SNE so that runs are reproducible.
SEED = 0

# Emit INFO-level progress messages through a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 30,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` into a low-dimensional space with t-SNE.

    Args:
        embeddings: Array handed to ``TSNE.fit_transform``.
        perplexity: Forwarded to ``TSNE`` as ``perplexity``.
        n_components: Dimensionality of the projection.
        init: Initialisation strategy forwarded to ``TSNE``.
        n_iter: Number of optimisation iterations. Forwarded as ``max_iter``
            when the installed scikit-learn supports that name, otherwise
            as ``n_iter``.
        random_state: Seed forwarded to ``TSNE``.

    Returns:
        The array returned by ``TSNE.fit_transform``.
    """
    # Function-scope import: only needed for the keyword detection below.
    import inspect

    tsne_kwargs = {
        "perplexity": perplexity,
        "n_components": n_components,
        "init": init,
        "random_state": random_state,
    }
    # scikit-learn 1.5 renamed TSNE's ``n_iter`` to ``max_iter`` and later
    # removed the old keyword, so pick whichever the installed version accepts.
    accepted = inspect.signature(TSNE).parameters
    iterations_kwarg = "max_iter" if "max_iter" in accepted else "n_iter"
    tsne_kwargs[iterations_kwarg] = n_iter

    return TSNE(**tsne_kwargs).fit_transform(embeddings)
|
|
|
def draw_interactive_scatter_plot(texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray) -> Any:
    """Build a Bokeh scatter plot whose points are coloured by ``values``.

    Each point is placed at ``(xs[i], ys[i])`` and filled with the palette
    colour obtained by min-max scaling ``values[i]`` onto the 256 palette
    entries. Hovering a point shows ``texts[i]`` and ``values[i]``.

    Args:
        texts: Text shown as "Sentence" in the hover tooltip.
        xs: Horizontal coordinates.
        ys: Vertical coordinates.
        values: Numeric values shown as "Perplexity" and used for colouring.

    Returns:
        The Bokeh figure, ready to be saved or shown.
    """
    min_value = values.min()
    value_range = values.max() - min_value
    if value_range > 0:
        palette_ids = ((values - min_value) / value_range * 255).round().astype(int)
    else:
        # All values identical (or not comparable): avoid dividing by zero
        # and give every point the first palette colour.
        palette_ids = np.zeros(len(values), dtype=int)

    # One colour per point, kept in the same order as the points themselves.
    # Pairing separately sorted value strings and colour strings by position
    # (as a categorical mapper would need) mismatches them, because strings
    # sort lexicographically ("10.0" < "2.0").
    colors = [Pallete[idx] for idx in palette_ids.tolist()]

    source = ColumnDataSource(
        data=dict(
            x=xs,
            y=ys,
            text=texts,
            perplexity=values.astype(str).tolist(),
            color=colors,
        )
    )
    # NOTE(review): the ``{safe}`` format is understood to disable HTML
    # escaping of the sentence text in the tooltip - confirm that the input
    # sentences are trusted.
    hover = HoverTool(tooltips=[('Sentence', '@text{safe}'), ('Perplexity', '@perplexity')])

    # Bokeh 3.0 removed ``plot_width``/``plot_height`` in favour of
    # ``width``/``height``; use whichever pair the installed figure exposes.
    if hasattr(figure, "plot_width"):
        size_kwargs = {"plot_width": 1200, "plot_height": 1200}
    else:
        size_kwargs = {"width": 1200, "height": 1200}

    p = figure(tools=[hover], title='Sentences', **size_kwargs)
    p.circle('x', 'y', size=10, source=source, fill_color='color')
    return p
|
|
|
def generate_plot(tsv: str, output_file_name: str, sample: Optional[int]) -> None:
    """Load a TSV of sentence embeddings, run t-SNE and save an HTML plot.

    The TSV is read with pandas and must provide the columns ``sentence`` and
    ``perplexity`` plus embedding columns whose names start with ``dim`` and
    end in ``_<integer>``; the embedding columns are ordered by that integer.

    Args:
        tsv: Path of the tab-separated input file.
        output_file_name: Path of the HTML file to write.
        sample: Number of rows to sample (seeded with ``SEED``). A falsy value
            keeps the whole dataset; a value larger than the dataset is capped
            to the dataset size.
    """
    logger.info("Loading dataset in memory")
    df = pd.read_csv(tsv, sep="\t")
    if sample:
        # ``DataFrame.sample`` raises when asked for more rows than exist,
        # so cap the request at the number of available rows.
        df = df.sample(min(sample, len(df)), random_state=SEED)
    logger.info("Dataset contains %d sentences", df.shape[0])

    # Order dim_0, dim_1, ..., dim_10 numerically rather than alphabetically.
    dim_columns = sorted(
        (col for col in df.columns if col.startswith("dim")),
        key=lambda name: int(name.split("_")[-1]),
    )
    embeddings = df[dim_columns].values

    logger.info("Running t-SNE")
    tsne_embeddings = get_tsne_embeddings(embeddings)

    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df["sentence"].values,
        tsne_embeddings[:, 0],
        tsne_embeddings[:, 1],
        df["perplexity"].values,
    )
    output_file(output_file_name)
    save(plot)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and build the plot.
    parser = argparse.ArgumentParser(description="Embeddings t-SNE plot")
    parser.add_argument(
        "--tsv",
        type=str,
        # Without a path there is nothing to load, so let argparse report the
        # missing argument instead of failing later inside pandas.
        required=True,
        help="Path to tsv file with columns 'sentence', 'perplexity' and N 'dim_<i>' columns, one for each embedding dimension.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path to the output HTML file for the interactive plot.",
        default="perplexity_colored_embeddings.html",
    )
    parser.add_argument("--sample", type=int, help="Number of sentences to use", default=None)

    args = parser.parse_args()
    generate_plot(args.tsv, args.output_file, args.sample)
|
|