import pandas as pd
import streamlit as st
import datasets
import plotly.express as px
from sentence_transformers import SentenceTransformer
from PIL import Image
import os
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
import subprocess
from tempfile import NamedTemporaryFile
from itertools import combinations
import networkx as nx
import plotly.graph_objects as go
import colorcet as cc
from matplotlib.colors import rgb2hex
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import hdbscan
import umap
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from datetime import datetime

# st.set_page_config(layout="wide")
model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"


def download_models():
    # The model directory doesn't exist yet: download and extract the multilingual CLIP text model
    subprocess.run(["mkdir", "models"])
    subprocess.run(["wget", "--no-check-certificate", "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"], check=True)
    subprocess.run(["unzip", "-q", "clip-ViT-B-32-multilingual-v1.zip", "-d", model_dir], check=True)


token_ = os.getenv("token")  # st.secrets["token"]

def load_dataset():
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token_)
    dataset.add_faiss_index(column="txt_embs")
    dataset.add_faiss_index(column="img_embs")
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Like and View Counts Disabled', 'Link', 'Download URL', 'Views'])
    return dataset

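# Rough sketch of how the FAISS indexes built above are queried later on
# (illustrative only; `query_vector` is a hypothetical embedding array):
#
#   scores, samples = dataset.get_nearest_examples("txt_embs", query_vector, k=5)
#
# `scores` holds the distances returned by the FAISS index (lower = more
# similar for the default flat L2 index) and `samples` the columns of the k
# nearest posts; postprocess_results() below turns the distances into a
# 0-100 match score.
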
def load_dataframe(_dataset):
    dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
    # Extract hashtags with regex and convert them to a set per post
    dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
    dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
    # Optionally drop specific hashtags (e.g. the project tag itself) from each post's hashtag set
    # dataframe['Hashtags'] = dataframe['Hashtags'].apply(lambda x: [item for item in x if not item.startswith('ditaduranuncamais')])
    # dataframe['Post Created'] = dataframe['Post Created'].dt.tz_convert('UTC')
    dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
    return dataframe

def load_img_model():
    # We use the original clip-ViT-B-32 for encoding images
    return SentenceTransformer('clip-ViT-B-32')


def load_txt_model():
    # Our text embedding model is aligned to the img_model and maps 50+
    # languages to the same vector space
    return SentenceTransformer('./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1')

def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    modify = st.checkbox("Add filters")

    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # Treat columns with < 10 unique values as categorical
            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df = df[df[column].str.contains(user_text_input)]

    return df

def get_image_embs(image):
    """
    Get image embeddings

    Parameters:
        image (str or file-like): Path to, or file object of, the image to encode

    Returns:
        img_emb (np.array): Image embeddings
    """
    img_emb = image_model.encode(Image.open(image))
    return img_emb


def get_text_embs(text):
    """
    Get text embeddings

    Parameters:
        text (str): Text to encode

    Returns:
        txt_emb (np.array): Text embeddings
    """
    txt_emb = text_model.encode(text)
    return txt_emb

def postprocess_results(scores, samples):
    """
    Postprocess retrieval results into a dataframe with normalised scores

    Parameters:
        scores (np.array): Distances returned by the FAISS index
        samples (datasets.Dataset): Retrieved samples

    Returns:
        samples_df (pd.DataFrame): Retrieved posts with a 0-100 score column
    """
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["score"] = scores
    samples_df["score"] = (1 - (samples_df["score"] - samples_df["score"].min()) / (
        samples_df["score"].max() - samples_df["score"].min())) * 100
    samples_df["score"] = samples_df["score"].astype(int)
    samples_df.reset_index(inplace=True, drop=True)
    samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
    return samples_df.drop(columns=['txt_embs', 'img_embs'])

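# The normalisation above rescales the FAISS distances into an intuitive
# 0-100 match score: with d_min and d_max the smallest and largest distance
# in the result set,
#
#   score = (1 - (d - d_min) / (d_max - d_min)) * 100
#
# so the closest hit gets 100 and the furthest gets 0. For example, distances
# [0.2, 0.5, 0.8] map to scores [100, 50, 0].
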
def text_to_text(text, k=5):
    """
    Text to text search

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        results (pd.DataFrame): Posts whose texts are most similar to the input text
    """
    text_emb = get_text_embs(text)
    scores, samples = dataset.get_nearest_examples('txt_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


def image_to_text(image, k=5):
    """
    Image to text search

    Parameters:
        image (NamedTemporaryFile): Temporary file holding the uploaded image
        k (int): Number of top results to return

    Returns:
        results (pd.DataFrame): Posts whose texts are most similar to the input image
    """
    img_emb = get_image_embs(image.name)
    scores, samples = dataset.get_nearest_examples('txt_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


def text_to_image(text, k=5):
    """
    Text to image search

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        results (pd.DataFrame): Posts whose images are most similar to the input text
    """
    text_emb = get_text_embs(text)
    scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


def image_to_image(image, k=5):
    """
    Image to image search

    Parameters:
        image (NamedTemporaryFile): Temporary file holding the uploaded image
        k (int): Number of top results to return

    Returns:
        results (pd.DataFrame): Posts whose images are most similar to the input image
    """
    img_emb = get_image_embs(image.name)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)

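# The four helpers above only differ in which embedding is computed for the
# query and which FAISS index is searched. A rough usage sketch (the query
# string and k are arbitrary examples, not app defaults):
#
#   hits = text_to_image("manifestação contra a ditadura", k=10)
#   # -> DataFrame of the 10 posts whose *images* are closest to the query text
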
def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
    """
    Computes the backbone of the input graph using the disparity filter algorithm.

    The algorithm is proposed in:
    M. A. Serrano, M. Boguna, and A. Vespignani,
    "Extracting the Multiscale Backbone of Complex Weighted Networks",
    PNAS, 106(16), pp 6483--6488 (2009).
    DOI: 10.1073/pnas.0808904106
    Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ

    Parameters
    ----------
    g : NetworkX graph
        The input graph.
    weight : str, optional (default='weight')
        The name of the edge attribute to use as weight.
    alpha : float, optional (default=0.05)
        The statistical significance level for the disparity filter (p-value).

    Returns
    -------
    backbone_graph : NetworkX graph
        The backbone graph.
    """
    # Create an empty graph for the backbone
    backbone_graph = nx.Graph()

    # Iterate over all nodes in the input graph
    for node in g:
        # Get the degree of the node (number of edges connected to the node)
        k_n = len(g[node])

        # Only proceed if the node has more than one connection
        if k_n > 1:
            # Calculate the sum of weights of edges connected to the node
            sum_w = sum(g[node][neighbor][weight] for neighbor in g[node])

            # Iterate over all neighbors of the node
            for neighbor in g[node]:
                # Get the weight of the edge between the node and its neighbor
                edge_weight = g[node][neighbor][weight]

                # Calculate the proportion of the node's total weight carried by this edge
                pij = float(edge_weight) / sum_w

                # Disparity filter test: keep the edge if it is statistically significant at level alpha
                if (1 - pij) ** (k_n - 1) < alpha:
                    backbone_graph.add_edge(node, neighbor, weight=edge_weight)

    # Return the backbone graph
    return backbone_graph

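# The test above keeps edge (i, j) when its p-value under the disparity
# filter's null model falls below alpha:
#
#   alpha_ij = (1 - p_ij) ** (k_i - 1) < alpha,   with p_ij = w_ij / sum_j w_ij
#
# i.e. the null hypothesis is that node i spreads its strength uniformly over
# its k_i edges. Small worked example: a node with k = 4 edges and weights
# [10, 1, 1, 1] gives p ≈ 0.77 for the heavy edge, so (1 - 0.77) ** 3 ≈ 0.012
# < 0.05 and that edge is kept, while the light edges (p ≈ 0.077,
# alpha_ij ≈ 0.79) are dropped.
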
@st.cache_data(show_spinner=True)
def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a unique color to each community in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    dict
        A dictionary mapping each community to a unique color.
    """
    glasbey_colors = cc.glasbey_hv
    communities_ = set(nx.get_node_attributes(G, attr).values())
    return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}

@st.cache_data(show_spinner=True)
def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    list
        A list of strings containing the hover text for each node.
    """
    return [f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}" for node, adjacencies in G.adjacency()]

@st.cache_data(show_spinner=True)
def calculate_node_sizes(G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    list
        A list of node sizes.
    """
    degrees = dict(G.degree())
    max_degree = max(deg for node, deg in degrees.items())
    return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]

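# Node sizes are scaled linearly between 10 and 30:
#
#   size(node) = 10 + 20 * degree(node) / max_degree
#
# e.g. with max_degree = 40, a node of degree 10 is drawn at size 15 and the
# most connected node at size 30.
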
def plot_graph(_G: nx.Graph, layout: str = "fdp", community_names_lookup: dict = None):
    """
    Plots a network graph with communities.

    Parameters
    ----------
    _G : nx.Graph
        The input graph.
    layout : str, optional
        The layout algorithm to use (default is "fdp").
    community_names_lookup : dict, optional
        Mapping from the default "Community N" labels to user-supplied names.
    """
    pos = nx.spring_layout(_G, dim=3, seed=779)
    community_colors = assign_community_colors(_G)
    node_colors = [community_colors[_G.nodes[n]['community']] for n in _G.nodes]

    # 2D preview figure (superseded by the 3D figure built below)
    edge_trace = go.Scatter(x=[item for sublist in [[pos[edge[0]][0], pos[edge[1]][0], None] for edge in _G.edges()] for item in sublist],
                            y=[item for sublist in [[pos[edge[0]][1], pos[edge[1]][1], None] for edge in _G.edges()] for item in sublist],
                            line=dict(width=0.5, color='#888'),
                            hoverinfo='none',
                            mode='lines')
    node_trace = go.Scatter(x=[pos[n][0] for n in _G.nodes()],
                            y=[pos[n][1] for n in _G.nodes()],
                            mode='markers',
                            hoverinfo='text',
                            marker=dict(color=node_colors, size=10, line_width=2))
    node_trace.text = generate_hover_text(_G)
    node_trace.marker.size = calculate_node_sizes(_G)
    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(title='Network graph with communities',
                                     title_font=dict(size=16),
                                     showlegend=False,
                                     hovermode='closest',
                                     margin=dict(b=20, l=5, r=5, t=40),
                                     xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                     yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                                     height=800))

    # Extract node positions
    Xn = [pos[k][0] for k in _G.nodes()]  # x-coordinates of nodes
    Yn = [pos[k][1] for k in _G.nodes()]  # y-coordinates
    Zn = [pos[k][2] for k in _G.nodes()]  # z-coordinates

    # Extract edge positions
    Xe = []
    Ye = []
    Ze = []
    for e in _G.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]  # x-coordinates of edge ends
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]

    # Define traces for plotly
    trace1 = go.Scatter3d(x=Xe,
                          y=Ye,
                          z=Ze,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                          )

    # Map community numbers to names (relies on the module-level `communities` list)
    community_names = {i: community_names_lookup[f"Community {i+1}"] for i in range(len(communities))}

    # Create hover text
    hover_text = [f"{node} ({community_names[_G.nodes[node]['community']]})" for node in _G.nodes()]

    trace2 = go.Scatter3d(x=Xn,
                          y=Yn,
                          z=Zn,
                          mode='markers',
                          name='actors',
                          marker=dict(symbol='circle',
                                      size=7,
                                      color=node_colors,  # pass hex colors
                                      line=dict(color='rgb(50,50,50)', width=0.2)
                                      ),
                          text=hover_text,  # Use community names as hover text
                          hoverinfo='text'
                          )

    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title=''
                )

    layout = go.Layout(
        title="3D Network Graph",
        width=1000,
        height=1000,
        showlegend=False,
        scene=dict(
            xaxis=dict(axis),
            yaxis=dict(axis),
            zaxis=dict(axis),
        ),
        margin=dict(
            t=100
        ),
        hovermode='closest',
    )

    data = [trace1, trace2]
    fig = go.Figure(data=data, layout=layout)
    return fig

def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA', n_clusters=5, min_cluster_size=5, n_components=2, n_neighbors=15, min_dist=0.0, random_state=42, min_samples=5):
    """
    A function to cluster embeddings.

    Args:
        embeddings (pd.Series): A series of numpy vectors.
        clustering_algo (str): The clustering algorithm to use. Either 'KMeans' or 'HDBSCAN'.
        dim_reduction (str): The dimensionality reduction method to use. Either 'PCA' or 'UMAP'.
        n_clusters (int): The number of clusters for KMeans.
        min_cluster_size (int): The minimum cluster size for HDBSCAN.
        n_components (int): The number of components for the dimensionality reduction method.
        n_neighbors (int): The number of neighbors for UMAP.
        min_dist (float): The minimum distance for UMAP.
        random_state (int): The seed used by the random number generator.
        min_samples (int): The minimum number of samples for HDBSCAN.

    Returns:
        tuple: An array of cluster labels and the reduced embeddings.
    """
    # Dimensionality reduction
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    elif dim_reduction == 'UMAP':
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)
    else:
        raise ValueError('Invalid dimensionality reduction method')

    reduced_embeddings = reducer.fit_transform(np.stack(embeddings))

    # Clustering
    if clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
    elif clustering_algo == 'HDBSCAN':
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    else:
        raise ValueError('Invalid clustering algorithm')

    labels = clusterer.fit_predict(reduced_embeddings)

    return labels, reduced_embeddings

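# Rough usage sketch (illustrative parameter values, not the app defaults):
#
#   labels, coords = cluster_embeddings(
#       dataset['txt_embs'],
#       clustering_algo='HDBSCAN',
#       dim_reduction='UMAP',
#       min_cluster_size=20,
#       min_samples=5,
#       n_components=50,
#   )
#
# `labels` holds one cluster id per post (HDBSCAN marks noise as -1) and
# `coords` holds the reduced vectors, which the "Clustering" page below also
# reduces to 2D for the scatter plot.
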
st.title("#ditaduranuncamais Data Explorer")

def check_password():
    """Returns `True` if the user had the correct password."""

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        if st.session_state["password"] == st.secrets["password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        # First run, show input for password.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        return False
    elif not st.session_state["password_correct"]:
        # Password not correct, show input + error.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        st.error("😕 Password incorrect")
        return False
    else:
        # Password correct.
        return True


if not check_password():
    st.stop()

# Check if the model directory exists; if not, download and extract the models
if not os.path.exists(model_dir):
    download_models()

dataset = load_dataset()
df = load_dataframe(dataset)
image_model = load_img_model()
text_model = load_txt_model()

menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)

if selected_menu_option == "Data exploration":
    st.dataframe(
        data=filter_dataframe(df),
        # use_container_width=True,
        column_config={
            "image": st.column_config.ImageColumn(
                "Image", help="Instagram image"
            ),
            "URL": st.column_config.LinkColumn(
                "Link", help="Instagram link", width="small"
            )
        },
        hide_index=True,
    )
elif selected_menu_option == "Semantic search":
    tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"]
    selected_tab = st.sidebar.radio("Select a search type", tabs)
    if selected_tab == "Text to Text":
        st.markdown('## Text to text search')
        text_to_text_input = st.text_input("Enter text")
        text_to_text_k_top = st.slider("Number of results", 1, 500, 20)
        if st.button("Search"):
            if not text_to_text_input:
                st.warning("Please enter text")
            else:
                st.dataframe(
                    data=text_to_text(text_to_text_input, text_to_text_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )
    elif selected_tab == "Text to Image":
        st.markdown('## Text to image search')
        text_to_image_input = st.text_input("Enter text")
        text_to_image_k_top = st.slider("Number of results", 1, 500, 20)
        if st.button("Search"):
            if not text_to_image_input:
                st.warning("Please enter some text")
            else:
                st.dataframe(
                    data=text_to_image(text_to_image_input, text_to_image_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )
    elif selected_tab == "Image to Image":
        st.markdown('## Image to image search')
        image_to_image_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        temp_file = NamedTemporaryFile(delete=False)
        if st.button("Search"):
            if not image_to_image_input:
                st.warning("Please upload an image")
            else:
                temp_file.write(image_to_image_input.getvalue())
                st.dataframe(
                    data=image_to_image(temp_file, image_to_image_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )
    elif selected_tab == "Image to Text":
        st.markdown('## Image to text search')
        image_to_text_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        temp_file = NamedTemporaryFile(delete=False)
        if st.button("Search"):
            if not image_to_text_input:
                st.warning("Please upload an image")
            else:
                temp_file.write(image_to_text_input.getvalue())
                st.dataframe(
                    data=image_to_text(temp_file, image_to_text_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )
elif selected_menu_option == "Hashtags":
    if 'dfx' not in st.session_state:
        st.session_state.dfx = df.copy()  # Make a working copy of the dataframe

    # Get a list of all unique hashtags in the DataFrame
    all_hashtags = list(set([item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]))

    st.sidebar.markdown('# Hashtag co-occurrence analysis options')

    # Let users select hashtags to remove
    hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags)

    col1, col2 = st.sidebar.columns(2)

    # Add a button to trigger the removal operation
    if col1.button("Remove hashtags"):
        st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove])

    # Add a reset button
    if col2.button("Reset"):
        st.session_state.dfx = df.copy()  # Reset dfx to the original DataFrame

    # Flatten the per-post hashtag lists
    hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]

    # Count the number of posts per hashtag
    hashtag_freq = st.session_state.dfx.explode('Hashtags').groupby('Hashtags').size().reset_index(name='counts')

    # Sort the hashtags by frequency
    hashtag_freq = hashtag_freq.sort_values(by='counts', ascending=False)

    # Make the scatter plot
    hashtags_fig = px.scatter(hashtag_freq, x='Hashtags', y='counts', log_y=True,  # log scale keeps the long tail readable
                              labels={'Hashtags': 'Hashtags', 'counts': 'Frequency'},
                              title='Frequency of hashtags in #ditaduranuncamais posts on Instagram',
                              height=600)

    st.markdown("### Hashtag Frequency Distribution")
    st.markdown('Here we apply hashtag co-occurrence analysis for mnemonic community detection: we build a network of hashtag pairs (which hashtags are used together in which posts) and then apply community detection algorithms to this network.')
    st.plotly_chart(hashtags_fig)

    weight_option = st.sidebar.radio(
        'Select weight definition',
        ('Number of users that use the hashtag pairs', 'Total number of occurrences')
    )

    hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)]

    # Create a DataFrame with columns 'hashtag_pair' and 'userid'
    hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name'])

    if weight_option == 'Number of users that use the hashtag pairs':
        # Group by 'hashtag_pair' and count the number of unique 'userid's
        hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index()
    elif weight_option == 'Total number of occurrences':
        # Group by 'hashtag_pair' and count the total number of occurrences
        hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name')

    # Build an edge list from hashtag_user_df with columns 'hashtag1', 'hashtag2', and 'weight'
    edge_list = hashtag_user_df.rename(columns={'hashtag_pair': 'hashtag1', 'User Name': 'weight'})
    edge_list[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_list['hashtag1'].tolist(), index=edge_list.index)
    edge_list = edge_list[['hashtag1', 'hashtag2', 'weight']]

    st.markdown("### Edge List of Hashtag Pairs")

    # Create the graph using the selected weight definition as the edge attribute
    G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight')
    G_backbone = disparity_filter(G, weight='weight', alpha=0.05)
    st.markdown(f'Number of nodes: {len(G_backbone.nodes)}')
    st.markdown(f'Number of edges: {len(G_backbone.edges)}')

    st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(10).style.set_caption("Edge list of hashtag pairs with the highest weight"))

    # Detect Louvain communities on the backbone graph
    communities = nx.community.louvain_communities(G_backbone, weight='weight', seed=1234)
    communities = list(communities)

    # Sort communities by size
    communities.sort(key=len, reverse=True)

    for i, community in enumerate(communities):
        for node in community:
            G_backbone.nodes[node]['community'] = i

    # Sort community hashtags based on their weighted degree in the network
    sorted_community_hashtags = [
        [
            hashtag
            for hashtag, degree in sorted(
                ((h, G.degree(h, weight='weight')) for h in community),
                key=lambda x: x[1],
                reverse=True
            )
        ]
        for community in communities
    ]

    # Convert the sorted_community_hashtags list into a DataFrame and transpose it
    sorted_community_hashtags = pd.DataFrame(sorted_community_hashtags).T

    # Rename the columns of the sorted_community_hashtags DataFrame
    sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))]

# add a st.data_editor with Community 1, etc as index and a column "community names" that sets Community 1 etc as default value | |
st.markdown("### Community Names") | |
st.markdown("Edit the names of the communities in the table below so they show up in the visualisations.") | |
df_community_names = pd.DataFrame(sorted_community_hashtags.columns, columns=['community_names'], index=sorted_community_hashtags.columns) | |
df_community_names = st.data_editor(df_community_names) | |
# download the edited df_community_names as csv | |
st.download_button( | |
label="Download community names as csv", | |
data=df_community_names.to_csv().encode("utf-8"), | |
file_name="community_names.csv", | |
mime="text/csv", | |
) | |
#create dict with community names | |
community_names_lookup = df_community_names['community_names'].to_dict() | |
# implement time series analysis of size of communities over time using resample_dict | |
st.markdown("### Community Size Over Time") | |
st.markdown("Select communites to see their size over time.") | |
# selected_communities = st.multiselect('Select Communities', [f'Community {i+1}' for i in range(len(communities))], default=[f'Community {i+1}' for i in range(len(communities))]) | |
selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=community_names_lookup.values()) | |
# Dropdown to select time resampling | |
resample_dict = { | |
'Day': 'D', | |
'Three Days': '3D', | |
'Week': 'W', | |
'Two Weeks': '2W', | |
'Month': 'M', | |
'Quarter': 'Q', | |
'Year': 'Y' | |
} | |
# Dropdown to select time resampling | |
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4) | |
df_communities = st.session_state.dfx.copy() | |
def community_dict(communities): | |
community_dict = {} | |
for i, community in enumerate(communities): | |
for node in community: | |
community_dict[node] = community_names_lookup[f'Community {i+1}'] | |
return community_dict | |
community_dict = community_dict(communities) | |
df_communities['Communities'] = df_communities['Hashtags'].apply(lambda x: [community_dict[tag] for tag in x if tag in community_dict.keys()]) | |
df_communities = df_communities[['Post Created', 'Communities']].explode('Communities') | |
df_communities = df_communities.dropna(subset=['Communities']) | |
# Slider for date range selection | |
min_date = df_communities['Post Created'].min().date() | |
max_date = df_communities['Post Created'].max().date() | |
date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date)) | |
# Filter df_communities by the selected date range | |
df_communities = df_communities[(df_communities['Post Created'].dt.date >= date_range[0]) & (df_communities['Post Created'].dt.date <= date_range[1])] | |
# Count the number of posts per community per resample_time | |
df_communities['Post Created'] = df_communities['Post Created'].dt.to_period(resample_dict[resample_time]) | |
df_community_sizes = df_communities.groupby(['Post Created', 'Communities']).size().unstack(fill_value=0) | |
df_community_sizes.index = df_community_sizes.index.to_timestamp() | |
# Filter the DataFrame to include only the selected communities | |
df_community_sizes = df_community_sizes[selected_communities] | |
st.plotly_chart(px.line(df_community_sizes, title='Community Size Over Time', labels={'value': 'Number of posts', 'index': 'Date', 'variable': 'Community'})) | |
st.markdown("### Hashtag Network Graph") | |
st.plotly_chart(plot_graph(G_backbone, layout="fdp", community_names_lookup=community_names_lookup)) # fdp is relatively slow, use 'sfdp' or 'neato' for faster but denser layouts | |
elif selected_menu_option == "Clustering":
    st.markdown("## Clustering")
    st.markdown("Select the type of embeddings to cluster, the clustering algorithm and the dimensionality reduction method in the sidebar. Clustering runs with the selected options and may take some time.")

    st.sidebar.markdown("# Clustering Options")
    type_embeddings = st.sidebar.selectbox("Type of embeddings to cluster", ["Text", "Image"])
    clustering_algo = st.sidebar.selectbox("Clustering algorithm", ["HDBSCAN", "KMeans"])
    dim_reduction = st.sidebar.selectbox("Dimensionality reduction method", ["UMAP", "PCA"])

    if clustering_algo == "KMeans":
        st.sidebar.markdown("### KMeans Options")
        n_clusters = st.sidebar.slider("Number of clusters", 2, 20, 5)
        min_cluster_size = None
        min_samples = None
    elif clustering_algo == "HDBSCAN":
        st.sidebar.markdown("### HDBSCAN Options")
        min_cluster_size = st.sidebar.slider("[Minimum cluster size](https://hdbscan.readthedocs.io/en/latest/parameter_selection.html#selecting-min-cluster-size)", 2, 200, 5)
        min_samples = st.sidebar.slider("[Minimum samples](https://hdbscan.readthedocs.io/en/latest/parameter_selection.html#selecting-min-samples)", 2, 50, 5)
        n_clusters = None

    if dim_reduction == "UMAP":
        st.sidebar.markdown("### UMAP Options")
        n_components = st.sidebar.slider("[Number of components](https://umap-learn.readthedocs.io/en/latest/parameters.html#n-components)", 2, 80, 50)
        n_neighbors = st.sidebar.slider("[Number of neighbors](https://umap-learn.readthedocs.io/en/latest/parameters.html#n-neighbors)", 2, 20, 15)
        min_dist = st.sidebar.slider("[Minimum distance](https://umap-learn.readthedocs.io/en/latest/parameters.html#min-dist)", 0.0, 1.0, 0.0)
    else:
        st.sidebar.markdown("### PCA Options")
        n_components = st.sidebar.slider("[Number of components](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html)", 2, 80, 2)
        n_neighbors = None
        min_dist = None

    st.markdown("### Clustering Results")

    if type_embeddings == "Text":
        embeddings = dataset['txt_embs']
    elif type_embeddings == "Image":
        embeddings = dataset['img_embs']

    # Cluster the embeddings with the selected algorithm and dimensionality reduction
    labels, reduced_embeddings = cluster_embeddings(embeddings, clustering_algo=clustering_algo, dim_reduction=dim_reduction, n_clusters=n_clusters, min_cluster_size=min_cluster_size, n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist, min_samples=min_samples)

    st.markdown(f"Clustering {type_embeddings} embeddings using {clustering_algo} with {dim_reduction} dimensionality reduction results in **{len(set(labels))}** clusters.")

    df_clustered = df.copy()
    df_clustered['cluster'] = labels
    df_clustered = df_clustered.set_index('cluster').reset_index()

    st.dataframe(
        data=filter_dataframe(df_clustered),
        # use_container_width=True,
        column_config={
            "image": st.column_config.ImageColumn(
                "Image", help="Instagram image"
            ),
            "URL": st.column_config.LinkColumn(
                "Link", help="Instagram link", width="small"
            )
        },
        hide_index=True,
    )

    st.download_button(
        "Download dataset with labels",
        df_clustered.to_csv(index=False).encode('utf-8'),
        f'ditaduranuncamais_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv',
        "text/csv",
        key='download-csv'
    )

    st.markdown("### Cluster Plot")

    # Scatter plot with cluster labels as colors; reduce further to 2 dimensions if n_components > 2
    if n_components > 2:
        reducer = umap.UMAP(n_components=2, random_state=42)
        reduced_embeddings = reducer.fit_transform(reduced_embeddings)

    # Visualise with bokeh, showing df_clustered['Description'] and df_clustered['image'] on hover
    descriptions = df_clustered['Description'].tolist()
    images = df_clustered['image'].tolist()

    glasbey_colors = cc.glasbey_hv
    color_dict = {n: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, n in enumerate(set(labels))}
    colors = [color_dict[label] for label in labels]

    source = ColumnDataSource(data=dict(
        x=reduced_embeddings[:, 0],
        y=reduced_embeddings[:, 1],
        desc=descriptions,
        imgs=images,
        colors=colors
    ))

    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@imgs" height="100" alt="@imgs" width="100"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
        <div>
            <span style="font-size: 12px; font-weight: bold;">@desc</span>
        </div>
    </div>
    """

    p = figure(width=800, height=800, tooltips=TOOLTIPS,
               title="Mouse over the dots")
    p.circle('x', 'y', size=10, source=source, color='colors', line_color=None)
    st.bokeh_chart(p)

    # Time series graph for clusters sorted by size (excluding noise cluster -1); show the top 5 by default, selectable below, reusing resample_dict for binning
    st.markdown("### Cluster Size")
    cluster_sizes = df_clustered.groupby('cluster').size().reset_index(name='counts')
    cluster_sizes = cluster_sizes.sort_values(by='counts', ascending=False)
    cluster_sizes = cluster_sizes[cluster_sizes['cluster'] != -1]
    cluster_sizes = cluster_sizes.set_index('cluster').reset_index()
    cluster_sizes = cluster_sizes.rename(columns={'cluster': 'Cluster', 'counts': 'Size'})
    st.dataframe(cluster_sizes)

    st.markdown("### Cluster Time Series")

    # Dropdown to select the variable to aggregate
    variable = st.selectbox('Select Variable', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'])

    # Mapping from UI labels to pandas resampling frequencies
    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    # Dropdown to select time resampling
    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    # Slider for date range selection
    min_date = df_clustered['Post Created'].min().date()
    max_date = df_clustered['Post Created'].max().date()
    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    # Filter dataframe based on selected date range
    df_resampled = df_clustered[(df_clustered['Post Created'].dt.date >= date_range[0]) & (df_clustered['Post Created'].dt.date <= date_range[1])]
    df_resampled = df_resampled.set_index('Post Created')

    # Get unique clusters and their sizes
    cluster_sizes = df_resampled[df_resampled['cluster'] != -1]['cluster'].value_counts()
    clusters = cluster_sizes.index

    # Select the largest 5 clusters by default
    default_clusters = cluster_sizes.sort_values(ascending=False).head(5).index.tolist()

    # Multiselect widget to choose clusters
    selected_clusters = st.multiselect('Select Clusters', options=clusters.tolist(), default=default_clusters)

    # Create a new DataFrame for the plot
    df_plot = pd.DataFrame()

    # Loop through selected clusters
    for cluster in selected_clusters:
        # Create a separate DataFrame for each cluster, resample and add to the plot DataFrame
        df_cluster = df_resampled[df_resampled['cluster'] == cluster][variable].resample(resample_dict[resample_time]).sum()
        df_plot = pd.concat([df_plot, df_cluster], axis=1)

    # Use the cluster numbers as the legend
    df_plot.columns = selected_clusters

    # Create the line chart
    st.line_chart(df_plot)
elif selected_menu_option == "Stats":
    st.markdown("### Time Series Analysis")

    # Dropdown to select the variable to aggregate
    variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'])

    # Mapping from UI labels to pandas resampling frequencies
    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    # Dropdown to select time resampling
    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    df_filtered = df.set_index('Post Created')

    # Slider for date range selection
    min_date = df_filtered.index.min().date()
    max_date = df_filtered.index.max().date()
    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    # Filter dataframe based on selected date range
    df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])]

    # Create a separate DataFrame for resampling and plotting
    df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum()
    st.line_chart(df_resampled)

    st.markdown("### Correlation Analysis")

    # Dropdowns to select variables for the scatter plot
    options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']
    scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options)
    # options.remove(scatter_variable_1)  # remove the chosen option from the list
    scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options)

    st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")

    # Plot scatter chart
    scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2)  # , trendline='ols', trendline_color_override='red')
    st.plotly_chart(scatter_fig)

    # Calculate the correlation between scatter_variable_1 and scatter_variable_2
    corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2])
    if corr > 0.7:
        st.write(f"The correlation coefficient is {corr}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > 0.3:
        st.write(f"The correlation coefficient is {corr}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > -0.3:
        st.write(f"The correlation coefficient is {corr}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > -0.7:
        st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
    else:
        st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")