wilmerags commited on
Commit
fe1cb2a
·
1 Parent(s): d556203

fix: Fix keyword selection to avoid duplicates

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -156,10 +156,11 @@ def generate_plot(
156
  cluster_words_embeddings = embed_text(cluster_words, model)
157
  cluster_to_words_similarities = util.dot_score(cluster_embeddings_avg, cluster_words_embeddings)
158
  cluster_to_words_similarities = cluster_to_words_similarities.numpy()
 
 
159
  while len(cluster_keyword[label]) < 3:
160
- most_descriptive = np.argmax(cluster_to_words_similarities)
161
- cluster_to_words_similarities = np.delete(cluster_to_words_similarities, most_descriptive)
162
- cluster_keyword[label].append(cluster_words[most_descriptive])
163
  cluster_keyword[label] = ', '.join(cluster_keyword[label])
164
  encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
165
  embeddings_2d = get_tsne_embeddings(embeddings)
 
156
  cluster_words_embeddings = embed_text(cluster_words, model)
157
  cluster_to_words_similarities = util.dot_score(cluster_embeddings_avg, cluster_words_embeddings)
158
  cluster_to_words_similarities = cluster_to_words_similarities.numpy()
159
+ cluster_to_words_similarities = [(word_ix, similarity) for word_ix, similarity in enumerate(cluster_to_words_similarities)]
160
+ cluster_to_words_similarities = sorted(cluster_to_words_similarities, key=lambda x: x[1], reverse=True)
161
  while len(cluster_keyword[label]) < 3:
162
+ most_descriptive = cluster_to_words_similarities.pop(0)
163
+ cluster_keyword[label].append(cluster_words[most_descriptive[0]])
 
164
  cluster_keyword[label] = ', '.join(cluster_keyword[label])
165
  encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
166
  embeddings_2d = get_tsne_embeddings(embeddings)