wilmerags commited on
Commit
5b5b795
1 Parent(s): d24710c

feat: Improve colors robustness for n topics

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -7,7 +7,8 @@ import tweepy
7
  import hdbscan
8
 
9
  from bokeh.models import ColumnDataSource, HoverTool, Label
10
- from bokeh.palettes import Category10 as Pallete
 
11
  from bokeh.plotting import Figure, figure
12
  from bokeh.transform import factor_cmap
13
 
@@ -53,8 +54,11 @@ def draw_interactive_scatter_plot(
53
  labels_list = labels.astype(str).tolist()
54
  source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
55
  hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
 
 
 
56
  p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
57
- colors = factor_cmap("label", palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set)
58
  p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors)
59
  p.axis.visible = False
60
  p.xgrid.grid_line_dash = "dashed"
 
7
  import hdbscan
8
 
9
  from bokeh.models import ColumnDataSource, HoverTool, Label
10
+ from bokeh.palettes import Colorblind as Pallete
11
+ from bokeh.palettes import Set3 as AuxPallete
12
  from bokeh.plotting import Figure, figure
13
  from bokeh.transform import factor_cmap
14
 
 
54
  labels_list = labels.astype(str).tolist()
55
  source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
56
  hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
57
+ n_colors = len(set(values_color_set))
58
+ if n_colors not in Pallete:
59
+ Palette = AuxPallete
60
  p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
61
+ colors = factor_cmap("label", palette=[Pallete[n_colors][int(id_)] for id_ in values_color_set], factors=values_set)
62
  p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors)
63
  p.axis.visible = False
64
  p.xgrid.grid_line_dash = "dashed"