# Last change: commit 7b62017 by edugp — "Run tokenizer before computing perplexity and format"
import numpy as np
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
def _perplexity_color_factors(values: np.ndarray) -> "tuple[list[str], list[str], list[str]]":
    """Map per-point perplexities to categorical labels and an aligned palette.

    Returns ``(labels, factors, palette)``:
      * ``labels``  - one string per input value (the log-scaled value),
      * ``factors`` - the distinct labels in ascending *numeric* order,
      * ``palette`` - ``palette[i]`` is the Cividis color for ``factors[i]``.

    Raises:
        ValueError: if ``values`` is empty.
    """
    # Smooth values for coloring: entropy = log10(perplexity), multiplied by
    # 10000 and rounded so near-identical perplexities share one factor.
    # NOTE(review): assumes values are strictly positive (perplexities);
    # zero/negative inputs give -inf/nan here — confirm upstream guarantees.
    scaled = (np.log10(values) * 10000).round().astype(int)
    if scaled.size == 0:
        raise ValueError("cannot draw a scatter plot of zero points")

    # np.unique returns values sorted numerically, so each factor and its
    # color are derived from the same element and can never drift apart
    # (sorting the string labels would put "10000" before "6990").
    distinct = np.unique(scaled)
    low, high = distinct[0], distinct[-1]
    if high == low:
        # Every point has the same value; an *integer* index is required to
        # look up the palette tuple (a float index raises TypeError).
        color_ids = np.ones(len(distinct), dtype=int)
    else:
        # Normalize to 0-255 to pick one of the 256 palette entries.
        color_ids = ((distinct - low) / (high - low) * 255).round().astype(int)

    labels = scaled.astype(str).tolist()
    factors = distinct.astype(str).tolist()
    palette = [Pallete[color_id] for color_id in color_ids]
    return labels, factors, palette


def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str,
) -> Figure:
    """Build an interactive 800x800 Bokeh scatter plot colored by perplexity.

    Args:
        texts: Per-point text shown in the hover tooltip.
        xs: Per-point x coordinates.
        ys: Per-point y coordinates.
        values: Per-point perplexity used to color the points (log-scaled and
            mapped onto the Cividis palette). Presumably strictly positive.
        labels: Per-point label shown in the hover tooltip.
        text_column: Caption for the text entry in the tooltip.
        label_column: Caption for the label entry in the tooltip.

    Returns:
        A figure with a hover tool, hidden axes, no grid lines and no logo.

    Raises:
        ValueError: if ``values`` is empty.
    """
    values_list, factors, palette = _perplexity_color_factors(values)
    labels_list = labels.astype(str).tolist()
    source = ColumnDataSource(
        data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)
    )
    # NOTE(review): "{safe}" renders the text as raw HTML in the browser; if
    # texts can contain untrusted markup, consider dropping "{safe}".
    hover = HoverTool(
        tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")]
    )
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle(
        "x",
        "y",
        size=10,
        source=source,
        fill_color=factor_cmap("label", palette=palette, factors=factors),
    )
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p