wilmerags committed on
Commit
7964bf4
·
1 Parent(s): 773a93e

feat: Add concrete clustering technique for visualization

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
 
5
  import streamlit as st
6
  import tweepy
 
7
 
8
  from bokeh.models import ColumnDataSource, HoverTool
9
  from bokeh.palettes import Cividis256 as Pallete
@@ -65,13 +66,15 @@ def draw_interactive_scatter_plot(
65
  # Up to here
66
  def generate_plot(
67
  df: List[str],
68
- labels: List[int],
69
  model: SentenceTransformer,
70
  ) -> Figure:
71
  with st.spinner(text="Embedding text..."):
72
  embeddings = embed_text(df, model)
73
  # encoded_labels = encode_labels(labels)
74
- encoded_labels = labels
 
 
 
75
  with st.spinner("Reducing dimensionality..."):
76
  embeddings_2d = get_tsne_embeddings(embeddings)
77
  plot = draw_interactive_scatter_plot(
@@ -112,8 +115,7 @@ if tw_user:
112
  if tw_sample > 0:
113
  tweets_response = client.get_users_tweets(usr.data.id, max_results=tw_sample)
114
  tweets_objs += tweets_response.data
115
- tweets_txt = [tweet.text for tweet in tweets_objs]
116
- labels = [0] * len(tweets_txt)
117
  # plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
118
- plot = generate_plot(tweets_txt, labels, model)
119
  st.bokeh_chart(plot)
 
4
 
5
  import streamlit as st
6
  import tweepy
7
+ import hdbscan
8
 
9
  from bokeh.models import ColumnDataSource, HoverTool
10
  from bokeh.palettes import Cividis256 as Pallete
 
66
  # Up to here
67
  def generate_plot(
68
  df: List[str],
 
69
  model: SentenceTransformer,
70
  ) -> Figure:
71
  with st.spinner(text="Embedding text..."):
72
  embeddings = embed_text(df, model)
73
  # encoded_labels = encode_labels(labels)
74
+ cluster = hdbscan.HDBSCAN(min_cluster_size=10,
75
+ metric='euclidean',
76
+ cluster_selection_method='eom').fit(umap_embeddings)
77
+ encoded_labels = cluster.labels_
78
  with st.spinner("Reducing dimensionality..."):
79
  embeddings_2d = get_tsne_embeddings(embeddings)
80
  plot = draw_interactive_scatter_plot(
 
115
  if tw_sample > 0:
116
  tweets_response = client.get_users_tweets(usr.data.id, max_results=tw_sample)
117
  tweets_objs += tweets_response.data
118
+ tweets_txt = [tweet.text for tweet in tweets_objs]
 
119
  # plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
120
+ plot = generate_plot(tweets_txt, model)
121
  st.bokeh_chart(plot)