rjadr committed on
Commit
3ffc79c
1 Parent(s): 72bddeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -4
app.py CHANGED
@@ -18,8 +18,15 @@ import networkx as nx
18
  import plotly.graph_objects as go
19
  import colorcet as cc
20
  from matplotlib.colors import rgb2hex
 
 
 
 
 
 
 
21
 
22
- st.set_page_config(layout="wide")
23
 
24
  model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
25
 
@@ -461,6 +468,49 @@ def plot_graph(_G: nx.Graph, layout: str = "fdp"):
461
  fig=go.Figure(data=data, layout=layout)
462
  return fig
463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  st.title("#ditaduranuncamais Data Explorer")
465
 
466
  def check_password():
@@ -503,7 +553,8 @@ df = load_dataframe(dataset)
503
  image_model = load_img_model()
504
  text_model = load_txt_model()
505
 
506
- menu_options = ["Data exploration", "Semantic search", "Hashtags", "Stats"]
 
507
  st.sidebar.markdown('# Menu')
508
  selected_menu_option = st.sidebar.radio("Select a page", menu_options)
509
 
@@ -634,7 +685,6 @@ elif selected_menu_option == "Hashtags":
634
  if col2.button("Reset"):
635
  st.session_state.dfx = df.copy() # Reset dfx to the original DataFrame
636
 
637
- # df2['Hashtags'] = df2['Hashtags'].apply(lambda x: [item for item in x if not item == 'ditaduranuncamais'])
638
  # Count the number of unique hashtags
639
  hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]
640
  # Count the number of posts per hashtag
@@ -689,7 +739,6 @@ elif selected_menu_option == "Hashtags":
689
  for node in community:
690
  G_backbone.nodes[node]['community'] = i
691
 
692
-
693
  # Sort community hashtags based on their weighted degree in the network
694
  sorted_community_hashtags = [
695
  [
@@ -716,6 +765,105 @@ elif selected_menu_option == "Hashtags":
716
  st.markdown("### Hashtag Network Graph")
717
  st.plotly_chart(plot_graph(G_backbone, layout="fdp")) # fdp is relatively slow, use 'sfdp' or 'neato' for faster but denser layouts
718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719
 
720
  elif selected_menu_option == "Stats":
721
  st.markdown("### Time Series Analysis")
 
18
  import plotly.graph_objects as go
19
  import colorcet as cc
20
  from matplotlib.colors import rgb2hex
21
+ from sklearn.cluster import KMeans
22
+ from sklearn.decomposition import PCA
23
+ import hdbscan
24
+ import umap
25
+ import numpy as np
26
+ from bokeh.plotting import figure
27
+ from bokeh.models import ColumnDataSource
28
 
29
+ #st.set_page_config(layout="wide")
30
 
31
  model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
32
 
 
468
  fig=go.Figure(data=data, layout=layout)
469
  return fig
470
 
471
@st.cache_data(show_spinner=True)
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA', n_clusters=5, min_cluster_size=5, n_components=2, n_neighbors=15, min_dist=0.0, random_state=42, min_samples=5):
    """
    Reduce the dimensionality of a set of embeddings, then cluster the result.

    Args:
        embeddings: A sequence of vectors; it is combined into one array with
            ``np.stack``, so every vector must have the same shape.
        clustering_algo (str): The clustering algorithm to use. Either 'KMeans' or 'HDBSCAN'.
        dim_reduction (str): The dimensionality reduction method to use. Either 'PCA' or 'UMAP'.
        n_clusters (int): The number of clusters for KMeans (ignored by HDBSCAN).
        min_cluster_size (int): The minimum cluster size for HDBSCAN (ignored by KMeans).
        n_components (int): The number of components for the dimensionality reduction method.
        n_neighbors (int): The number of neighbors for UMAP (ignored by PCA).
        min_dist (float): The minimum distance for UMAP (ignored by PCA).
        random_state (int): Seed passed to PCA, UMAP and KMeans.
        min_samples (int): The minimum number of samples for HDBSCAN (ignored by KMeans).

    Returns:
        tuple: ``(labels, reduced_embeddings)`` where ``labels`` is the output of
        the clusterer's ``fit_predict`` (one label per input vector) and
        ``reduced_embeddings`` is the output of the reducer's ``fit_transform``.

    Raises:
        ValueError: If ``dim_reduction`` or ``clustering_algo`` is not one of
            the supported names.
    """
    # Validate both option names before fitting anything, so that a bad
    # clustering_algo does not waste a (potentially slow) reduction pass.
    if dim_reduction not in ('PCA', 'UMAP'):
        raise ValueError('Invalid dimensionality reduction method')
    if clustering_algo not in ('KMeans', 'HDBSCAN'):
        raise ValueError('Invalid clustering algorithm')

    # Dimensionality reduction
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    else:
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)

    reduced_embeddings = reducer.fit_transform(np.stack(embeddings))

    # Clustering
    if clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
    else:
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)

    labels = clusterer.fit_predict(reduced_embeddings)

    return labels, reduced_embeddings
513
+
514
  st.title("#ditaduranuncamais Data Explorer")
515
 
516
  def check_password():
 
553
  image_model = load_img_model()
554
  text_model = load_txt_model()
555
 
556
+ menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
557
+
558
  st.sidebar.markdown('# Menu')
559
  selected_menu_option = st.sidebar.radio("Select a page", menu_options)
560
 
 
685
  if col2.button("Reset"):
686
  st.session_state.dfx = df.copy() # Reset dfx to the original DataFrame
687
 
 
688
  # Count the number of unique hashtags
689
  hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]
690
  # Count the number of posts per hashtag
 
739
  for node in community:
740
  G_backbone.nodes[node]['community'] = i
741
 
 
742
  # Sort community hashtags based on their weighted degree in the network
743
  sorted_community_hashtags = [
744
  [
 
765
  st.markdown("### Hashtag Network Graph")
766
  st.plotly_chart(plot_graph(G_backbone, layout="fdp")) # fdp is relatively slow, use 'sfdp' or 'neato' for faster but denser layouts
767
 
768
+ elif selected_menu_option == "Clustering":
769
+ st.markdown("## Clustering")
770
+ st.markdown("Select the type of embeddings to cluster and the clustering algorithm and dimensionality reduction method to use in the sidebar. Then click run clustering. Clustering may take some time.")
771
+ st.sidebar.markdown("# Clustering Options")
772
+ type_embeddings = st.sidebar.selectbox("Type of embeddings to cluster", ["Text", "Image"])
773
+ clustering_algo = st.sidebar.selectbox("Clustering algorithm", ["HDBSCAN", "KMeans"])
774
+ dim_reduction = st.sidebar.selectbox("Dimensionality reduction method", ["UMAP", "PCA"])
775
+ if clustering_algo == "KMeans":
776
+ st.sidebar.markdown("### KMeans Options")
777
+ n_clusters = st.sidebar.slider("Number of clusters", 2, 20, 5)
778
+ min_cluster_size = None
779
+ min_samples = None
780
+ elif clustering_algo == "HDBSCAN":
781
+ st.sidebar.markdown("### HDBSCAN Options")
782
+ min_cluster_size = st.sidebar.slider("[Minimum cluster size](https://github.com/scikit-learn-contrib/hdbscan/blob/master/docs/parameter_selection.rst)", 2, 200, 5)
783
+ min_samples = st.sidebar.slider("Minimum samples", 2, 50, 5)
784
+ n_clusters = None
785
+ if dim_reduction == "UMAP":
786
+ st.sidebar.markdown("### UMAP Options")
787
+ n_components = st.sidebar.slider("Number of dimensions", 2, 80, 50)
788
+ n_neighbors = st.sidebar.slider("Number of neighbors", 2, 20, 15)
789
+ min_dist = st.sidebar.slider("Minimum distance", 0.0, 1.0, 0.0)
790
+ else:
791
+ st.sidebar.markdown("### PCA Options")
792
+ n_components = st.sidebar.slider("Number of dimensions", 2, 80, 2)
793
+ n_neighbors = None
794
+ min_dist = None
795
+
796
+ if st.sidebar.button('Run clustering'):
797
+ st.markdown("### Clustering Results")
798
+ if type_embeddings == "Text":
799
+ embeddings = dataset['txt_embs']
800
+ elif type_embeddings == "Image":
801
+ embeddings = dataset['img_embs']
802
+
803
+ # Cluster embeddings
804
+ labels, reduced_embeddings = cluster_embeddings(embeddings, clustering_algo=clustering_algo, dim_reduction=dim_reduction, n_clusters=n_clusters, min_cluster_size=min_cluster_size, n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist)
805
+ st.markdown(f"Clustering {type_embeddings} embeddings using {clustering_algo} with {dim_reduction} dimensionality reduction method resulting in **{len(set(labels))}** clusters.")
806
+
807
+ df_clustered = df.copy()
808
+ df_clustered['cluster'] = labels
809
+ df_clustered = df_clustered.set_index('cluster').reset_index()
810
+ st.dataframe(
811
+ data=filter_dataframe(df_clustered),
812
+ # use_container_width=True,
813
+ column_config={
814
+ "image": st.column_config.ImageColumn(
815
+ "Image", help="Instagram image"
816
+ ),
817
+ "URL": st.column_config.LinkColumn(
818
+ "Link", help="Instagram link", width="small"
819
+ )
820
+ },
821
+ hide_index=True,
822
+ )
823
+
824
+ st.markdown("### Cluster Plot")
825
+ # Plot the scatter plot in plotly with the cluster labels as colors reduce further to 2 dimensions if n_components > 2
826
+ if n_components > 2:
827
+ reducer = umap.UMAP(n_components=2, random_state=42)
828
+ reduced_embeddings = reducer.fit_transform(reduced_embeddings)
829
+ # set the labels to be the cluster labels dynamically
830
+
831
+ # visualise with bokeh showing df_clustered['Description'] and df_clustered['image'] on hover
832
+ descriptions = df_clustered['Description'].tolist()
833
+ images = df_clustered['image'].tolist()
834
+ glasbey_colors = cc.glasbey_hv
835
+ color_dict = {n: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, n in enumerate(set(labels))}
836
+ colors = [color_dict[label] for label in labels]
837
+
838
+ source = ColumnDataSource(data=dict(
839
+ x=reduced_embeddings[:, 0],
840
+ y=reduced_embeddings[:, 1],
841
+ desc=descriptions,
842
+ imgs=images,
843
+ colors=colors
844
+ ))
845
+
846
+ TOOLTIPS = """
847
+ <div>
848
+ <div>
849
+ <img
850
+ src="@imgs" height="100" alt="@imgs" width="100"
851
+ style="float: left; margin: 0px 15px 15px 0px;"
852
+ border="2"
853
+ ></img>
854
+ </div>
855
+ <div>
856
+ <span style="font-size: 12px; font-weight: bold;">@desc</span>
857
+ </div>
858
+ </div>
859
+ """
860
+
861
+ p = figure(width=800, height=800, tooltips=TOOLTIPS,
862
+ title="Mouse over the dots")
863
+
864
+ p.circle('x', 'y', size=10, source=source, color='colors', line_color=None)
865
+ st.bokeh_chart(p)
866
+
867
 
868
  elif selected_menu_option == "Stats":
869
  st.markdown("### Time Series Analysis")