wilmerags commited on
Commit
4b05a76
·
1 Parent(s): d447a02

test: Remove coercion to str from plotting function and joining keywords to get single string

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -98,8 +98,7 @@ def draw_interactive_scatter_plot(
98
  values_color_set = sorted(values_color)
99
  values_list = values.astype(str).tolist()
100
  values_set = sorted(values_list)
101
- labels_list = labels.astype(str).tolist()
102
- source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
103
  hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
104
  n_colors = len(set(values_color_set))
105
  if n_colors not in Pallete:
@@ -158,6 +157,7 @@ def generate_plot(
158
  most_descriptive = np.argmax(cluster_to_words_similarities)
159
  cluster_to_words_similarities = np.delete(cluster_to_words_similarities, most_descriptive)
160
  cluster_keyword[label].append(cluster_words[most_descriptive])
 
161
  encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
162
  embeddings_2d = get_tsne_embeddings(embeddings)
163
  plot = draw_interactive_scatter_plot(
 
98
  values_color_set = sorted(values_color)
99
  values_list = values.astype(str).tolist()
100
  values_set = sorted(values_list)
101
+ source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels))
 
102
  hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
103
  n_colors = len(set(values_color_set))
104
  if n_colors not in Pallete:
 
157
  most_descriptive = np.argmax(cluster_to_words_similarities)
158
  cluster_to_words_similarities = np.delete(cluster_to_words_similarities, most_descriptive)
159
  cluster_keyword[label].append(cluster_words[most_descriptive])
160
+ cluster_keyword[label] = ', '.join(cluster_keyword[label])
161
  encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
162
  embeddings_2d = get_tsne_embeddings(embeddings)
163
  plot = draw_interactive_scatter_plot(