Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -18,8 +18,15 @@ import networkx as nx
|
|
18 |
import plotly.graph_objects as go
|
19 |
import colorcet as cc
|
20 |
from matplotlib.colors import rgb2hex
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
st.set_page_config(layout="wide")
|
23 |
|
24 |
model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
|
25 |
|
@@ -461,6 +468,49 @@ def plot_graph(_G: nx.Graph, layout: str = "fdp"):
|
|
461 |
fig=go.Figure(data=data, layout=layout)
|
462 |
return fig
|
463 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
464 |
st.title("#ditaduranuncamais Data Explorer")
|
465 |
|
466 |
def check_password():
|
@@ -503,7 +553,8 @@ df = load_dataframe(dataset)
|
|
503 |
image_model = load_img_model()
|
504 |
text_model = load_txt_model()
|
505 |
|
506 |
-
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Stats"]
|
|
|
507 |
st.sidebar.markdown('# Menu')
|
508 |
selected_menu_option = st.sidebar.radio("Select a page", menu_options)
|
509 |
|
@@ -634,7 +685,6 @@ elif selected_menu_option == "Hashtags":
|
|
634 |
if col2.button("Reset"):
|
635 |
st.session_state.dfx = df.copy() # Reset dfx to the original DataFrame
|
636 |
|
637 |
-
# df2['Hashtags'] = df2['Hashtags'].apply(lambda x: [item for item in x if not item == 'ditaduranuncamais'])
|
638 |
# Count the number of unique hashtags
|
639 |
hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]
|
640 |
# Count the number of posts per hashtag
|
@@ -689,7 +739,6 @@ elif selected_menu_option == "Hashtags":
|
|
689 |
for node in community:
|
690 |
G_backbone.nodes[node]['community'] = i
|
691 |
|
692 |
-
|
693 |
# Sort community hashtags based on their weighted degree in the network
|
694 |
sorted_community_hashtags = [
|
695 |
[
|
@@ -716,6 +765,105 @@ elif selected_menu_option == "Hashtags":
|
|
716 |
st.markdown("### Hashtag Network Graph")
|
717 |
st.plotly_chart(plot_graph(G_backbone, layout="fdp")) # fdp is relatively slow, use 'sfdp' or 'neato' for faster but denser layouts
|
718 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
719 |
|
720 |
elif selected_menu_option == "Stats":
|
721 |
st.markdown("### Time Series Analysis")
|
|
|
18 |
import plotly.graph_objects as go
|
19 |
import colorcet as cc
|
20 |
from matplotlib.colors import rgb2hex
|
21 |
+
from sklearn.cluster import KMeans
|
22 |
+
from sklearn.decomposition import PCA
|
23 |
+
import hdbscan
|
24 |
+
import umap
|
25 |
+
import numpy as np
|
26 |
+
from bokeh.plotting import figure
|
27 |
+
from bokeh.models import ColumnDataSource
|
28 |
|
29 |
+
#st.set_page_config(layout="wide")
|
30 |
|
31 |
model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
|
32 |
|
|
|
468 |
fig=go.Figure(data=data, layout=layout)
|
469 |
return fig
|
470 |
|
471 |
+
@st.cache_data(show_spinner=True)
|
472 |
+
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA', n_clusters=5, min_cluster_size=5, n_components=2, n_neighbors=15, min_dist=0.0, random_state=42, min_samples=5):
|
473 |
+
"""
|
474 |
+
A function to cluster embeddings.
|
475 |
+
|
476 |
+
Args:
|
477 |
+
embeddings (pd.Series): A series of numpy vectors.
|
478 |
+
clustering_algo (str): The clustering algorithm to use. Either 'KMeans' or 'HDBSCAN'.
|
479 |
+
dim_reduction (str): The dimensionality reduction method to use. Either 'PCA' or 'UMAP'.
|
480 |
+
n_clusters (int): The number of clusters for KMeans.
|
481 |
+
min_cluster_size (int): The minimum cluster size for HDBSCAN.
|
482 |
+
n_components (int): The number of components for the dimensionality reduction method.
|
483 |
+
n_neighbors (int): The number of neighbors for UMAP.
|
484 |
+
min_dist (float): The minimum distance for UMAP.
|
485 |
+
random_state (int): The seed used by the random number generator.
|
486 |
+
min_samples (int): The minimum number of samples for HDBSCAN.
|
487 |
+
|
488 |
+
Returns:
|
489 |
+
pd.Series: A series of cluster labels.
|
490 |
+
"""
|
491 |
+
|
492 |
+
# Dimensionality reduction
|
493 |
+
if dim_reduction == 'PCA':
|
494 |
+
reducer = PCA(n_components=n_components, random_state=random_state)
|
495 |
+
elif dim_reduction == 'UMAP':
|
496 |
+
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)
|
497 |
+
else:
|
498 |
+
raise ValueError('Invalid dimensionality reduction method')
|
499 |
+
|
500 |
+
reduced_embeddings = reducer.fit_transform(np.stack(embeddings))
|
501 |
+
|
502 |
+
# Clustering
|
503 |
+
if clustering_algo == 'KMeans':
|
504 |
+
clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
|
505 |
+
elif clustering_algo == 'HDBSCAN':
|
506 |
+
clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
|
507 |
+
else:
|
508 |
+
raise ValueError('Invalid clustering algorithm')
|
509 |
+
|
510 |
+
labels = clusterer.fit_predict(reduced_embeddings)
|
511 |
+
|
512 |
+
return labels, reduced_embeddings
|
513 |
+
|
514 |
st.title("#ditaduranuncamais Data Explorer")
|
515 |
|
516 |
def check_password():
|
|
|
553 |
image_model = load_img_model()
|
554 |
text_model = load_txt_model()
|
555 |
|
556 |
+
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
|
557 |
+
|
558 |
st.sidebar.markdown('# Menu')
|
559 |
selected_menu_option = st.sidebar.radio("Select a page", menu_options)
|
560 |
|
|
|
685 |
if col2.button("Reset"):
|
686 |
st.session_state.dfx = df.copy() # Reset dfx to the original DataFrame
|
687 |
|
|
|
688 |
# Count the number of unique hashtags
|
689 |
hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]
|
690 |
# Count the number of posts per hashtag
|
|
|
739 |
for node in community:
|
740 |
G_backbone.nodes[node]['community'] = i
|
741 |
|
|
|
742 |
# Sort community hashtags based on their weighted degree in the network
|
743 |
sorted_community_hashtags = [
|
744 |
[
|
|
|
765 |
st.markdown("### Hashtag Network Graph")
|
766 |
st.plotly_chart(plot_graph(G_backbone, layout="fdp")) # fdp is relatively slow, use 'sfdp' or 'neato' for faster but denser layouts
|
767 |
|
768 |
+
elif selected_menu_option == "Clustering":
|
769 |
+
st.markdown("## Clustering")
|
770 |
+
st.markdown("Select the type of embeddings to cluster and the clustering algorithm and dimensionality reduction method to use in the sidebar. Then click run clustering. Clustering may take some time.")
|
771 |
+
st.sidebar.markdown("# Clustering Options")
|
772 |
+
type_embeddings = st.sidebar.selectbox("Type of embeddings to cluster", ["Text", "Image"])
|
773 |
+
clustering_algo = st.sidebar.selectbox("Clustering algorithm", ["HDBSCAN", "KMeans"])
|
774 |
+
dim_reduction = st.sidebar.selectbox("Dimensionality reduction method", ["UMAP", "PCA"])
|
775 |
+
if clustering_algo == "KMeans":
|
776 |
+
st.sidebar.markdown("### KMeans Options")
|
777 |
+
n_clusters = st.sidebar.slider("Number of clusters", 2, 20, 5)
|
778 |
+
min_cluster_size = None
|
779 |
+
min_samples = None
|
780 |
+
elif clustering_algo == "HDBSCAN":
|
781 |
+
st.sidebar.markdown("### HDBSCAN Options")
|
782 |
+
min_cluster_size = st.sidebar.slider("[Minimum cluster size](https://github.com/scikit-learn-contrib/hdbscan/blob/master/docs/parameter_selection.rst)", 2, 200, 5)
|
783 |
+
min_samples = st.sidebar.slider("Minimum samples", 2, 50, 5)
|
784 |
+
n_clusters = None
|
785 |
+
if dim_reduction == "UMAP":
|
786 |
+
st.sidebar.markdown("### UMAP Options")
|
787 |
+
n_components = st.sidebar.slider("Number of dimensions", 2, 80, 50)
|
788 |
+
n_neighbors = st.sidebar.slider("Number of neighbors", 2, 20, 15)
|
789 |
+
min_dist = st.sidebar.slider("Minimum distance", 0.0, 1.0, 0.0)
|
790 |
+
else:
|
791 |
+
st.sidebar.markdown("### PCA Options")
|
792 |
+
n_components = st.sidebar.slider("Number of dimensions", 2, 80, 2)
|
793 |
+
n_neighbors = None
|
794 |
+
min_dist = None
|
795 |
+
|
796 |
+
if st.sidebar.button('Run clustering'):
|
797 |
+
st.markdown("### Clustering Results")
|
798 |
+
if type_embeddings == "Text":
|
799 |
+
embeddings = dataset['txt_embs']
|
800 |
+
elif type_embeddings == "Image":
|
801 |
+
embeddings = dataset['img_embs']
|
802 |
+
|
803 |
+
# Cluster embeddings
|
804 |
+
labels, reduced_embeddings = cluster_embeddings(embeddings, clustering_algo=clustering_algo, dim_reduction=dim_reduction, n_clusters=n_clusters, min_cluster_size=min_cluster_size, n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist)
|
805 |
+
st.markdown(f"Clustering {type_embeddings} embeddings using {clustering_algo} with {dim_reduction} dimensionality reduction method resulting in **{len(set(labels))}** clusters.")
|
806 |
+
|
807 |
+
df_clustered = df.copy()
|
808 |
+
df_clustered['cluster'] = labels
|
809 |
+
df_clustered = df_clustered.set_index('cluster').reset_index()
|
810 |
+
st.dataframe(
|
811 |
+
data=filter_dataframe(df_clustered),
|
812 |
+
# use_container_width=True,
|
813 |
+
column_config={
|
814 |
+
"image": st.column_config.ImageColumn(
|
815 |
+
"Image", help="Instagram image"
|
816 |
+
),
|
817 |
+
"URL": st.column_config.LinkColumn(
|
818 |
+
"Link", help="Instagram link", width="small"
|
819 |
+
)
|
820 |
+
},
|
821 |
+
hide_index=True,
|
822 |
+
)
|
823 |
+
|
824 |
+
st.markdown("### Cluster Plot")
|
825 |
+
# Plot the scatter plot in plotly with the cluster labels as colors reduce further to 2 dimensions if n_components > 2
|
826 |
+
if n_components > 2:
|
827 |
+
reducer = umap.UMAP(n_components=2, random_state=42)
|
828 |
+
reduced_embeddings = reducer.fit_transform(reduced_embeddings)
|
829 |
+
# set the labels to be the cluster labels dynamically
|
830 |
+
|
831 |
+
# visualise with bokeh showing df_clustered['Description'] and df_clustered['image'] on hover
|
832 |
+
descriptions = df_clustered['Description'].tolist()
|
833 |
+
images = df_clustered['image'].tolist()
|
834 |
+
glasbey_colors = cc.glasbey_hv
|
835 |
+
color_dict = {n: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, n in enumerate(set(labels))}
|
836 |
+
colors = [color_dict[label] for label in labels]
|
837 |
+
|
838 |
+
source = ColumnDataSource(data=dict(
|
839 |
+
x=reduced_embeddings[:, 0],
|
840 |
+
y=reduced_embeddings[:, 1],
|
841 |
+
desc=descriptions,
|
842 |
+
imgs=images,
|
843 |
+
colors=colors
|
844 |
+
))
|
845 |
+
|
846 |
+
TOOLTIPS = """
|
847 |
+
<div>
|
848 |
+
<div>
|
849 |
+
<img
|
850 |
+
src="@imgs" height="100" alt="@imgs" width="100"
|
851 |
+
style="float: left; margin: 0px 15px 15px 0px;"
|
852 |
+
border="2"
|
853 |
+
></img>
|
854 |
+
</div>
|
855 |
+
<div>
|
856 |
+
<span style="font-size: 12px; font-weight: bold;">@desc</span>
|
857 |
+
</div>
|
858 |
+
</div>
|
859 |
+
"""
|
860 |
+
|
861 |
+
p = figure(width=800, height=800, tooltips=TOOLTIPS,
|
862 |
+
title="Mouse over the dots")
|
863 |
+
|
864 |
+
p.circle('x', 'y', size=10, source=source, color='colors', line_color=None)
|
865 |
+
st.bokeh_chart(p)
|
866 |
+
|
867 |
|
868 |
elif selected_menu_option == "Stats":
|
869 |
st.markdown("### Time Series Analysis")
|