# rjadr — "Update app.py" (commit 690e825, verified)
import pandas as pd
import streamlit as st
import datasets
import plotly.express as px
from transformers import AutoProcessor, AutoModel
from PIL import Image
import os
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
import subprocess
from tempfile import NamedTemporaryFile
from itertools import combinations
import networkx as nx
import plotly.graph_objects as go
import colorcet as cc
from matplotlib.colors import rgb2hex
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
import hdbscan
import umap
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from datetime import datetime
import re
# Uncomment to use Streamlit's wide page layout (set_page_config must be the
# first Streamlit command executed).
#st.set_page_config(layout="wide")
# Hugging Face hub id of the CLIP checkpoint used to embed search queries.
model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
# Hub access token passed to datasets.load_dataset, read from Streamlit secrets
# (the script fails here if the "token" secret is not configured).
token_ = st.secrets["token"]
@st.cache_resource(show_spinner=True)
def load_model(model_name):
    """
    Fetch the processor and model for ``model_name`` from the Hugging Face hub.

    Wrapped in ``st.cache_resource`` so the pair is created once and shared
    across reruns instead of being reloaded on every interaction.

    Returns:
        tuple: ``(processor, model)``
    """
    # Tuple elements evaluate left to right: processor first, then model.
    return (
        AutoProcessor.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )
@st.cache_resource(show_spinner=True)
def load_dataset():
    """
    Download the dataset and attach FAISS indexes on both embedding columns.

    Cached with ``st.cache_resource`` so a single Dataset object, including its
    in-memory FAISS indexes, is shared across reruns. ``st.cache_data`` would
    pickle the result and hand every rerun a freshly deserialised copy of the
    whole index.

    Returns:
        datasets.Dataset: the train split without the unused metadata columns,
        with FAISS indexes named ``text_embs`` and ``img_embs`` (index names
        default to the column names).
    """
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_)
    # Drop unused metadata first so the FAISS indexes are built on the final
    # Dataset object rather than on an intermediate one.
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Like and View Counts Disabled', 'Link', 'Download URL', 'Views'])
    dataset.add_faiss_index(column="text_embs")
    dataset.add_faiss_index(column="img_embs")
    return dataset
@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    """
    Build the pandas view of the dataset used by the explorer pages.

    The embedding columns are dropped. A ``Hashtags`` column is added, holding the
    set of lower-cased hashtags found in ``Description`` and ``Image Text``. A
    ``description_clean`` column is added, holding the description without
    hashtags, truncated. The most useful columns are moved to the front.

    ``_dataset`` starts with an underscore so Streamlit leaves it out of the
    cache key.
    """
    dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas()
    # Vectorised concat instead of a row-wise apply. astype(str) formats missing
    # values like the f-string did ('None'/'nan'), which contain no '#'.
    combined_text = dataframe['Description'].astype(str) + ' ' + dataframe['Image Text'].astype(str)
    # Extract hashtags with regex and convert to set
    dataframe['Hashtags'] = combined_text.str.lower().str.findall(r'#(\w+)').apply(set)
    # Create a cleaned description column up-front
    dataframe['description_clean'] = dataframe['Description'].apply(clean_and_truncate_text)
    # Reorder columns to keep the new column next to the original
    leading_cols = ['Post Created', 'image', 'Description', 'description_clean', 'Image Text', 'Account', 'User Name']
    dataframe = dataframe[leading_cols + [col for col in dataframe.columns if col not in leading_cols]]
    return dataframe
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Renders an "Add filters" checkbox. When it is unticked, ``df`` is returned
    unchanged. Otherwise the user picks columns and gets one widget per column:
    - a multiselect for categorical or low-cardinality (< 10 distinct) columns;
    - a range slider for numeric columns;
    - an inclusive date-range picker for datetime columns;
    - a substring/regex box for everything else, including columns of
      unhashable values such as sets or dicts.

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe (built from a copy; ``df`` is not mutated)
    """
    modify = st.checkbox("Add filters")

    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberate best effort: most object columns (free text,
                # hashtag sets, image dicts) are simply not dates.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            series = df[column]
            # nunique() hashes the cell values and raises TypeError for
            # unhashable cells (sets, dicts); such columns fall through to
            # the text filter below.
            try:
                n_unique = series.nunique()
            except TypeError:
                n_unique = None
            # Treat columns with < 10 unique values as categorical
            if isinstance(series.dtype, pd.CategoricalDtype) or (
                n_unique is not None and n_unique < 10
            ):
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    series.unique(),
                    default=list(series.unique()),
                )
                df = df[series.isin(user_cat_input)]
            elif is_numeric_dtype(series):
                _min = float(series.min())
                _max = float(series.max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[series.between(*user_num_input)]
            elif is_datetime64_any_dtype(series):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        series.min(),
                        series.max(),
                    ),
                )
                # While the user is still picking, date_input may return a single date.
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)
                    # The picker yields midnight timestamps. Compare against the
                    # start of the following day so rows later on the end date are kept.
                    end_exclusive = end_date + pd.Timedelta(days=1)
                    df = df.loc[(series >= start_date) & (series < end_exclusive)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    # Searching the string form lets non-str cells (sets,
                    # dicts, missing values) participate and guarantees a
                    # boolean mask without NaN.
                    haystack = series.astype(str)
                    try:
                        mask = haystack.str.contains(user_text_input)
                    except re.error:
                        right.warning("Invalid regular expression; matching it as plain text.")
                        mask = haystack.str.contains(user_text_input, regex=False)
                    df = df[mask]

    return df
@st.cache_data
def get_image_embs(_processor, _model, uploaded_file):
    """
    Get image embeddings

    Parameters:
        _processor (transformers.AutoProcessor): Processor for the model
            (underscore-prefixed so Streamlit does not hash it)
        _model (transformers.AutoModel): Model to use for embeddings
        uploaded_file (str or file-like): Image source accepted by ``PIL.Image.open``;
            callers in this app pass a temp-file path

    Returns:
        img_emb (np.ndarray): L2-normalised image embedding with the batch
        dimension squeezed out
    """
    import torch  # return_tensors="pt" below already requires PyTorch

    # Context manager closes the file handle PIL opens. The processor has
    # already converted the pixels to tensors before the handle is released.
    with Image.open(uploaded_file) as image:
        inputs = _processor(images=image, return_tensors="pt")
    # Inference only: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = _model.get_image_features(**inputs)
        img_embs = outputs / outputs.norm(dim=-1, keepdim=True)
    return img_embs.squeeze(0).cpu().numpy()
@st.cache_data(show_spinner=False)
def get_text_embs(_processor, _model, text):
    """
    Get text embeddings

    Parameters:
        _processor (transformers.AutoProcessor): Processor for the model
            (underscore-prefixed so Streamlit does not hash it)
        _model (transformers.AutoModel): Model to use for embeddings
        text (str): Text to encode; tokens beyond 77 are truncated

    Returns:
        text_emb (np.ndarray): L2-normalised text embedding with the batch
        dimension squeezed out
    """
    import torch  # return_tensors="pt" below already requires PyTorch

    # Process the text with truncation
    inputs = _processor(
        text=text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77  # CLIP's maximum sequence length
    )
    # Inference only: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = _model.get_text_features(**inputs)
        text_embs = outputs / outputs.norm(dim=-1, keepdim=True)
    return text_embs.squeeze(0).cpu().numpy()
@st.cache_data
def postprocess_results(scores, samples):
    """
    Turn nearest-neighbour output into a display-ready DataFrame.

    Parameters:
        scores (np.array): Distances returned by ``Dataset.get_nearest_examples``
            (lower means closer)
        samples (dict): Column name -> list of values for the retrieved rows

    Returns:
        pd.DataFrame: one row per hit, with these properties:
        - an integer ``score`` column in 0..100, where 100 is the closest hit in
          this result set and 0 the farthest; if all distances are equal, every
          hit scores 100;
        - the key columns placed first;
        - the raw embedding columns dropped.
    """
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["score"] = scores
    score_min = samples_df["score"].min()
    score_range = samples_df["score"].max() - score_min
    if score_range > 0:
        # Min-max scale and invert so that smaller distances score higher.
        samples_df["score"] = (1 - (samples_df["score"] - score_min) / score_range) * 100
    else:
        # A single hit (k=1) or identical distances: scaling would be 0/0 = NaN
        # and astype(int) would raise. Every hit is equally the best one.
        samples_df["score"] = 100
    samples_df["score"] = samples_df["score"].astype(int)
    samples_df.reset_index(inplace=True, drop=True)
    leading_cols = ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']
    samples_df = samples_df[leading_cols + [col for col in samples_df.columns if col not in leading_cols]]
    return samples_df.drop(columns=['text_embs', 'img_embs'])
@st.cache_data
def text_to_text(text, k=5):
    """
    Text to text: rank posts by how close their *text* embedding is to ``text``.

    Parameters:
        text (str): Query text, embedded with ``get_text_embs``
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: the ``k`` nearest posts, formatted by ``postprocess_results``
    """
    query_vec = get_text_embs(processor, model, text)
    hits = dataset.get_nearest_examples('text_embs', query_vec, k=k)
    # hits unpacks as (scores, samples)
    return postprocess_results(*hits)
@st.cache_data
def image_to_text(image, k=5):
    """
    Image to text: rank posts by how close their *text* embedding is to an image.

    Parameters:
        image: open file-like object with a ``.name`` path on disk (the app
            passes a ``NamedTemporaryFile``) holding the uploaded image bytes
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: the ``k`` nearest posts, formatted by ``postprocess_results``
    """
    # The image is re-opened by path, so bytes still sitting in the file
    # object's write buffer must reach disk first or PIL sees a truncated file.
    if hasattr(image, "flush"):
        image.flush()
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k)
    return postprocess_results(scores, samples)
@st.cache_data
def text_to_image(text, k=5):
    """
    Text to image: rank posts by how close their *image* embedding is to ``text``.

    Parameters:
        text (str): Query text, embedded with ``get_text_embs``
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: the ``k`` nearest posts, formatted by ``postprocess_results``
    """
    query_vec = get_text_embs(processor, model, text)
    hits = dataset.get_nearest_examples('img_embs', query_vec, k=k)
    # hits unpacks as (scores, samples)
    return postprocess_results(*hits)
@st.cache_data
def image_to_image(image, k=5):
    """
    Image to image: rank posts by how close their *image* embedding is to an image.

    Parameters:
        image: open file-like object with a ``.name`` path on disk (the app
            passes a ``NamedTemporaryFile``) holding the uploaded image bytes
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: the ``k`` nearest posts, formatted by ``postprocess_results``
    """
    # The image is re-opened by path, so bytes still sitting in the file
    # object's write buffer must reach disk first or PIL sees a truncated file.
    if hasattr(image, "flush"):
        image.flush()
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)
def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
    """
    Computes the backbone of the input graph using the disparity filter algorithm.

    The algorithm is proposed in:
    M. A. Serrano, M. Boguna, and A. Vespignani,
    "Extracting the Multiscale Backbone of Complex Weighted Networks",
    PNAS, 106(16), pp 6483--6488 (2009).
    DOI: 10.1073/pnas.0808904106
    Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ

    For every node with more than one neighbour, each incident edge is kept when
    ``(1 - w_ij / s_i) ** (k_i - 1) < alpha``. Here ``s_i`` is the node's total
    edge weight and ``k_i`` its degree. An edge survives if it passes the test
    from either endpoint. Only surviving edges and their endpoints appear in
    the result, and only the weight attribute is copied.

    Parameters
    ----------
    g : NetworkX graph
        The input graph.
    weight : str, optional (default='weight')
        The name of the edge attribute to use as weight.
    alpha : float, optional (default=0.05)
        The statistical significance level for the disparity filter (p-value).

    Returns
    -------
    backbone_graph : NetworkX graph
        The backbone graph.
    """
    backbone_graph = nx.Graph()
    for node in g:
        neighbours = g[node]
        degree = len(neighbours)
        # A node with a single edge is not tested itself; that edge can still
        # be kept via the test at its other endpoint.
        if degree <= 1:
            continue
        strength = sum(attrs[weight] for attrs in neighbours.values())
        for neighbour, attrs in neighbours.items():
            share = float(attrs[weight]) / strength
            if (1 - share) ** (degree - 1) < alpha:
                backbone_graph.add_edge(node, neighbour, weight=attrs[weight])
    return backbone_graph
def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a distinct color to each community in the input graph.

    Community labels are sorted before colors are handed out, matching the
    ordering ``plot_graph`` uses. This makes the mapping reproducible; iterating
    a raw ``set`` of string labels depends on hash randomisation. Colors come
    from the Glasbey palette and repeat once there are more communities than
    palette entries.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    dict
        A dictionary mapping each community to a hex color string.
    """
    palette = cc.glasbey_hv
    labels = set(nx.get_node_attributes(G, attr).values())
    try:
        ordered = sorted(labels)
    except TypeError:
        # Mutually unorderable label types: fall back to a deterministic textual order.
        ordered = sorted(labels, key=repr)
    return {community: rgb2hex(palette[i % len(palette)]) for i, community in enumerate(ordered)}
def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.

    Produces one HTML snippet per node, in ``G.adjacency()`` order. Each snippet
    shows the node, its community label plus one, and its number of neighbours.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community indexes (default is 'community').
        Values must be numeric because the label is displayed as ``value + 1``
        (presumably turning 0-based indexes into 1-based ones).

    Returns
    -------
    list
        A list of strings containing the hover text for each node.

    Raises
    ------
    KeyError
        If a node does not carry ``attr``.
    """
    return [
        f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}"
        for node, adjacencies in G.adjacency()
    ]
def calculate_node_sizes(G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.

    Sizes scale linearly from 10 (degree 0) to 30 (the maximum degree in ``G``).
    An empty graph yields ``[]``, and a graph whose nodes are all isolated gives
    every node the base size 10.0.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    list
        A list of node sizes, in ``G.nodes()`` order.
    """
    degrees = dict(G.degree())
    max_degree = max(degrees.values(), default=0)
    if max_degree == 0:
        # Nothing to scale against: avoid max() on an empty sequence and 0/0.
        return [10.0 for _ in G.nodes()]
    return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]
def _graph_fingerprint(G: nx.Graph) -> str:
    """Return a stable digest of G's nodes, their 'community' labels and weighted edges."""
    import hashlib

    nodes = sorted((repr(n), repr(attrs.get('community'))) for n, attrs in G.nodes(data=True))
    # Orientation-independent edge keys so the digest ignores insertion order.
    edges = sorted(
        (*sorted((repr(u), repr(v))), repr(attrs.get('weight')))
        for u, v, attrs in G.edges(data=True)
    )
    return hashlib.sha256(repr((nodes, edges)).encode('utf-8')).hexdigest()


@st.cache_data(show_spinner=True)
def _plot_graph_cached(_G: nx.Graph, graph_fingerprint: str, layout_name: str, community_names_lookup: dict):
    """Build the 3D network figure.

    Streamlit does not hash ``_G``, so ``graph_fingerprint`` stands in for the
    graph in the cache key.
    """
    # --- Select the layout algorithm ---
    if layout_name == "kamada_kawai":
        # Aesthetically pleasing, can be slow on large graphs.
        pos = nx.kamada_kawai_layout(_G, dim=3)
    elif layout_name == "circular":
        # Fast, simple circle. It's 2D, so we add a Z-coordinate.
        pos_2d = nx.circular_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    elif layout_name == "spectral":
        # Good for showing clusters. Also 2D, so we add a Z-coordinate.
        pos_2d = nx.spectral_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    else:  # Default to "spring"
        # The standard physics-based layout.
        pos = nx.spring_layout(_G, dim=3, k=0.15, iterations=50, seed=779)

    # --- Colors: cycle the palette so any number of communities gets a color ---
    communities = sorted(set(nx.get_node_attributes(_G, 'community').values()))
    palette = cc.glasbey_hv
    community_colors = {comm: palette[i % len(palette)] for i, comm in enumerate(communities)}

    edge_x, edge_y, edge_z = [], [], []
    for u, v in _G.edges():
        x0, y0, z0 = pos[u]
        x1, y1, z1 = pos[v]
        # None breaks the polyline between consecutive edges.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    data = [edge_trace]

    # Bucket nodes by community in a single pass over the graph.
    nodes_by_community = {comm: [] for comm in communities}
    for node, attrs in _G.nodes(data=True):
        comm = attrs.get('community')
        if comm in nodes_by_community:
            nodes_by_community[comm].append(node)

    for comm_idx in communities:
        comm_key = f'Community {comm_idx + 1}'
        comm_name = community_names_lookup.get(comm_key, comm_key)

        node_x, node_y, node_z, node_text = [], [], [], []
        for node in nodes_by_community[comm_idx]:
            x, y, z = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_z.append(z)
            node_text.append(f"Hashtag: #{node}<br>Community: {comm_name}")

        node_trace = go.Scatter3d(
            x=node_x, y=node_y, z=node_z,
            mode='markers',
            name=comm_name,
            marker=dict(
                symbol='circle',
                size=7,
                color=rgb2hex(community_colors[comm_idx]),
                line=dict(color='rgb(50,50,50)', width=0.5)
            ),
            text=node_text,
            hoverinfo='text'
        )
        data.append(node_trace)

    layout = go.Layout(
        title="3D Hashtag Network Graph",
        showlegend=True,
        legend=dict(title="Communities", x=1.05, y=0.5),
        width=1000,
        height=800,
        margin=dict(l=0, r=0, b=0, t=40),
        scene=dict(
            xaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''),
            yaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''),
            zaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
        )
    )

    return go.Figure(data=data, layout=layout)


def plot_graph(_G: nx.Graph, layout_name: str = "spring", community_names_lookup: dict = None):
    """
    Plots a network graph with communities and a legend, using a choice of pure-Python layouts.

    Results are cached per (graph content, layout, names). A regenerated graph
    therefore produces a new figure even when the layout and names are unchanged.

    Parameters
    ----------
    _G : nx.Graph
        The input graph with a 'community' attribute on each node.
    layout_name : str, optional
        One of "spring" (default / fallback), "kamada_kawai", "circular", "spectral".
    community_names_lookup : dict, optional
        A dictionary mapping community key (e.g., 'Community 1') to a display name.
        Missing keys (or ``None``) fall back to the key itself.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    return _plot_graph_cached(
        _G,
        _graph_fingerprint(_G),
        layout_name,
        community_names_lookup or {},
    )
def clean_and_truncate_text(text, max_length=30):
    """
    Strip hashtags from ``text`` and shorten it for compact display.

    Args:
        text (str): The input string to clean. Any non-string value (e.g. NaN
            or None from a missing cell) yields "".
        max_length (int): Number of characters kept before "..." is appended.

    Returns:
        str: The text without hashtags and surrounding whitespace. If it is
        longer than ``max_length``, it is cut to ``max_length`` characters
        followed by "...".
    """
    if isinstance(text, str):
        # Drop every "#word" together with any whitespace that follows it.
        cleaned = re.sub(r'#\w+\s*', '', text).strip()
        return cleaned if len(cleaned) <= max_length else f"{cleaned[:max_length]}..."
    return ""
@st.cache_data(show_spinner=True)
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA',
                       # KMeans & MiniBatchKMeans params
                       n_clusters=5, batch_size=256, max_iter=100,
                       # HDBSCAN params
                       min_cluster_size=5, min_samples=5,
                       # Reducer params
                       n_components=2, n_neighbors=15, min_dist=0.0, random_state=42):
    """Performs dimensionality reduction and clustering on a set of embeddings.

    This function chains two steps. First, it reduces the dimensionality of the
    input embeddings using either PCA or UMAP. Second, it applies a clustering
    algorithm (KMeans, MiniBatchKMeans, or HDBSCAN) to the reduced-dimensional
    data to assign a cluster label to each embedding. Both option strings are
    validated before any computation, so a bad option fails immediately rather
    than after a long reduction.

    Args:
        embeddings (list or np.ndarray): A list or array of high-dimensional
            embedding vectors. Each element should be a 1D NumPy array.
        clustering_algo (str, optional): The clustering algorithm to use.
            Options are 'KMeans', 'MiniBatchKMeans', or 'HDBSCAN'.
            Defaults to 'KMeans'.
        dim_reduction (str, optional): The dimensionality reduction method to use.
            Options are 'PCA' or 'UMAP'. Defaults to 'PCA'.
        n_clusters (int, optional): The number of clusters to form. Used by
            KMeans and MiniBatchKMeans. Defaults to 5.
        batch_size (int, optional): The size of mini-batches for MiniBatchKMeans.
            Defaults to 256.
        max_iter (int, optional): The maximum number of iterations for
            MiniBatchKMeans. Defaults to 100.
        min_cluster_size (int, optional): The minimum number of samples in a
            group for it to be considered a cluster. Used by HDBSCAN.
            Defaults to 5.
        min_samples (int, optional): The number of samples in a neighborhood for
            a point to be considered a core point. Used by HDBSCAN.
            Defaults to 5.
        n_components (int, optional): The number of dimensions to reduce to.
            Used by PCA and UMAP. Defaults to 2.
        n_neighbors (int, optional): The number of neighbors to consider for
            manifold approximation. Used by UMAP. Defaults to 15.
        min_dist (float, optional): The effective minimum distance between
            embedded points. Used by UMAP. Defaults to 0.0.
        random_state (int, optional): The seed used by the random number
            generator for reproducibility. Defaults to 42.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: An array of cluster labels assigned to each embedding
              (HDBSCAN marks noise as -1).
            - np.ndarray: The reduced-dimensional representation of the embeddings.

    Raises:
        ValueError: If an invalid `clustering_algo` or `dim_reduction` method
            is specified.
    """
    # Fail fast on bad options, in the same order the steps run.
    if dim_reduction not in ('PCA', 'UMAP'):
        raise ValueError('Invalid dimensionality reduction method')
    if clustering_algo not in ('KMeans', 'MiniBatchKMeans', 'HDBSCAN'):
        raise ValueError('Invalid clustering algorithm')

    # Stack embeddings into a single NumPy array
    data_array = np.stack(embeddings)

    # --- 1. Dimensionality Reduction ---
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    else:  # 'UMAP'
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)
    reduced_embeddings = reducer.fit_transform(data_array)

    # --- 2. Clustering ---
    if clustering_algo == 'MiniBatchKMeans':
        clusterer = MiniBatchKMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            batch_size=batch_size,
            max_iter=max_iter,
            n_init='auto'  # Recommended setting
        )
    elif clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto')
    else:  # 'HDBSCAN'
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)

    labels = clusterer.fit_predict(reduced_embeddings)
    return labels, reduced_embeddings
# Page heading; rendered on every rerun, before the password gate runs.
st.title("#ditaduranuncamais Data Explorer")
def check_password():
    """Returns `True` if user is authenticated, `False` otherwise.

    Renders a password box until the value typed matches
    ``st.secrets["password"]``. Success is remembered in
    ``st.session_state["password_correct"]``, so later reruns skip the form.
    A missing or non-string secret never authenticates anyone.
    """
    import hmac  # constant-time comparison of the secret

    # If the user is already authenticated, just return True.
    # This is the most important part: we don't render the password form again.
    if st.session_state.get("password_correct", False):
        return True

    # This part of the code will only run if the user is not yet authenticated.
    def password_entered():
        """on_change callback: check the typed password and record the outcome."""
        entered = st.session_state.get("password")
        expected = st.secrets.get("password")
        # compare_digest avoids leaking how much of the password matched via
        # timing. Comparing UTF-8 bytes also supports non-ASCII passwords
        # (compare_digest rejects non-ASCII str).
        is_correct = (
            isinstance(entered, str)
            and isinstance(expected, str)
            and hmac.compare_digest(entered.encode("utf-8"), expected.encode("utf-8"))
        )
        st.session_state["password_correct"] = is_correct
        if is_correct:
            # Don't store the password in session state.
            del st.session_state["password"]

    # Show the password input form.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )

    # Show an error message if the last attempt was incorrect.
    # The 'in' check prevents the error from showing on the first load.
    if "password_correct" in st.session_state and not st.session_state.password_correct:
        st.error("😕 Password incorrect")

    # Return False to stop the main app from running.
    return False
# Halt this script run (st.stop) until the user has authenticated, so nothing
# below, including the dataset and model loads, executes for anonymous visitors.
if not check_password():
    st.stop()
# Load the cached dataset with its FAISS indexes, the derived pandas view, and
# the CLIP processor/model used to embed search queries.
dataset = load_dataset()
df = load_dataframe(dataset)
processor, model = load_model(model_name)
# Sidebar navigation: the selected page name drives the if/elif dispatch below.
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)
if selected_menu_option == "Data exploration":
st.dataframe(
data=filter_dataframe(df),
# use_container_width=True,
column_config={
"image": st.column_config.ImageColumn(
"Image", help="Instagram image"
),
"URL": st.column_config.LinkColumn(
"Link", help="Instagram link", width="small"
)
},
hide_index=True,
)
elif selected_menu_option == "Semantic search":
tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"]
selected_tab = st.sidebar.radio("Select a search type", tabs)
if selected_tab == "Text to Text":
st.markdown('## Text to text search')
text_to_text_input = st.text_input("Enter text")
text_to_text_k_top = st.slider("Number of results", 1, 500, 20)
if st.button("Search"):
if not text_to_text_input:
st.warning("Please enter text")
else:
st.dataframe(
data=text_to_text(text_to_text_input, text_to_text_k_top),
column_config={
"image": st.column_config.ImageColumn(
"Image", help="Instagram image"
),
"URL": st.column_config.LinkColumn(
"Link", help="Instagram link", width="small"
)
},
hide_index=True,
)
elif selected_tab == "Text to Image":
st.markdown('## Text to image search')
text_to_image_input = st.text_input("Enter text")
text_to_image_k_top = st.slider("Number of results", 1, 500, 20)
if st.button("Search"):
if not text_to_image_input:
st.warning("Please enter some text")
else:
st.dataframe(
data=text_to_image(text_to_image_input, text_to_image_k_top),
column_config={
"image": st.column_config.ImageColumn(
"Image", help="Instagram image"
),
"URL": st.column_config.LinkColumn(
"Link", help="Instagram link", width="small"
)
},
hide_index=True,
)
elif selected_tab == "Image to Image":
st.markdown('## Image to image search')
image_to_image_k_top = st.slider("Number of results", 1, 500, 20)
image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
temp_file = NamedTemporaryFile(delete=False)
if st.button("Search"):
if not image_to_image_input:
st.warning("Please upload an image")
else:
temp_file.write(image_to_image_input.getvalue())
st.dataframe(
data=image_to_image(temp_file, image_to_image_k_top),
column_config={
"image": st.column_config.ImageColumn(
"Image", help="Instagram image"
),
"URL": st.column_config.LinkColumn(
"Link", help="Instagram link", width="small"
)
},
hide_index=True,
)
elif selected_tab == "Image to Text":
st.markdown('## Image to text search')
image_to_text_k_top = st.slider("Number of results", 1, 500, 20)
image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
temp_file = NamedTemporaryFile(delete=False)
if st.button("Search"):
if not image_to_text_input:
st.warning("Please upload an image")
else:
temp_file.write(image_to_text_input.getvalue())
st.dataframe(
data=image_to_text(temp_file, image_to_text_k_top),
column_config={
"image": st.column_config.ImageColumn(
"Image", help="Instagram image"
),
"URL": st.column_config.LinkColumn(
"Link", help="Instagram link", width="small"
)
},
hide_index=True,
)
elif selected_menu_option == "Hashtags":
st.markdown("### Hashtag Co-occurrence Analysis")
st.markdown("This section creates a network of hashtags based on how often they are used together. Use the sidebar to configure the analysis, then click the button to generate the network and identify communities.")
# --- Sidebar Configuration (no changes) ---
if 'dfx' not in st.session_state:
st.session_state.dfx = df.copy()
all_hashtags = sorted(list(set(item for sublist in st.session_state.dfx['Hashtags'] for item in sublist)))
st.sidebar.markdown('## Hashtag Network Options')
hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags)
col1, col2 = st.sidebar.columns(2)
if col1.button("Remove hashtags"):
st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove])
if 'hashtag_results' in st.session_state:
del st.session_state.hashtag_results
st.rerun()
if col2.button("Reset Hashtags"):
st.session_state.dfx = df.copy()
if 'hashtag_results' in st.session_state:
del st.session_state.hashtag_results
st.rerun()
weight_option = st.sidebar.radio(
'Select weight definition',
('Number of users that use the hashtag pairs', 'Total number of occurrences')
)
# --- Main Page Content ---
if st.button("Generate Hashtag Network", type="primary"):
with st.spinner("Building graph, filtering edges, and detecting communities..."):
# (Calculation code remains the same as before...)
hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)]
hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name'])
if weight_option == 'Number of users that use the hashtag pairs':
edge_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index()
else:
edge_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name')
edge_df = edge_df.rename(columns={'User Name': 'weight'})
edge_df[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_df['hashtag_pair'].tolist(), index=edge_df.index)
edge_list = edge_df[['hashtag1', 'hashtag2', 'weight']]
G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight')
G_backbone = disparity_filter(G, weight='weight', alpha=0.05)
communities = list(nx.community.louvain_communities(G_backbone, weight='weight', seed=1234))
communities.sort(key=len, reverse=True)
for i, community in enumerate(communities):
for node in community:
G_backbone.nodes[node]['community'] = i
sorted_community_hashtags = pd.DataFrame([
[h for h, _ in sorted(((h, G.degree(h, weight='weight')) for h in com), key=lambda x: x[1], reverse=True)]
for com in communities
]).T
sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))]
# Initialize the community names dataframe and store it in session state
df_community_names = pd.DataFrame(
sorted_community_hashtags.columns,
columns=['community_names'],
index=sorted_community_hashtags.columns
)
st.session_state.community_names_df = df_community_names
st.session_state.hashtag_results = {
"G_backbone": G_backbone,
"communities": communities,
"sorted_community_hashtags": sorted_community_hashtags,
"edge_list": edge_list
}
st.rerun()
# --- Display Results Section ---
if 'hashtag_results' in st.session_state:
results = st.session_state.hashtag_results
G_backbone = results['G_backbone']
communities = results['communities']
sorted_community_hashtags = results['sorted_community_hashtags']
edge_list = results['edge_list']
st.success(f"Network generated! Found **{len(communities)}** communities from **{len(G_backbone.nodes)}** hashtags and **{len(G_backbone.edges)}** connections.")
# Define the tabs with the editor in its own tab
tab_graph, tab_editor, tab_timeline, tab_lists = st.tabs([
"📊 Network Graph",
"📝 Edit Community Names",
"🕒 Community Timelines",
"📋 Community & Edge Lists"
])
with tab_graph:
st.markdown("### Hashtag Network Graph")
st.markdown("Nodes represent hashtags, colored by community. The legend uses the names from the 'Edit Community Names' tab.")
# Re-introduce the layout selector with safe, pure-Python options
layout_options = {
"Spring": "spring",
"Kamada-Kawai": "kamada_kawai",
"Circular": "circular",
"Spectral": "spectral"
}
selected_layout_name = st.selectbox(
"Graph Layout Algorithm",
options=layout_options.keys()
)
# Get the actual function name string
layout_alg_str = layout_options[selected_layout_name]
# Retrieve edited names from session state
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict()
# Call the plot function with the chosen layout
fig = plot_graph(
_G=G_backbone,
layout_name=layout_alg_str,
community_names_lookup=community_names_lookup
)
st.plotly_chart(fig, use_container_width=True)
with tab_editor:
st.markdown("### Edit Community Names")
st.markdown("Change the default community names in the table below. The new names will automatically update the graph legend and the timeline chart.")
# The data editor modifies the dataframe in session_state
edited_df = st.data_editor(
st.session_state.community_names_df,
use_container_width=True,
num_rows="dynamic" # Allows for adding/removing if needed, though less likely here
)
# Persist any changes back to session state
st.session_state.community_names_df = edited_df
st.download_button(
label="Download Community Names as CSV",
data=edited_df.to_csv().encode("utf-8"),
file_name="community_names.csv",
mime="text/csv",
)
with tab_timeline:
st.markdown("### Community Size Over Time")
# Retrieve the latest names from session state for the multiselect options
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict()
selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=list(community_names_lookup.values()))
resample_dict = {'Day': 'D', 'Week': 'W', 'Month': 'M', 'Quarter': 'Q', 'Year': 'Y'}
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4)
community_dict = {node: community_names_lookup.get(f'Community {i+1}') for i, comm_set in enumerate(communities) for node in comm_set}
df_communities = st.session_state.dfx.copy()
df_communities['Communities'] = df_communities['Hashtags'].apply(lambda tags: list(set(community_dict.get(tag) for tag in tags if tag in community_dict)))
df_communities = df_communities.explode('Communities').dropna(subset=['Communities'])
df_ts = df_communities.set_index('Post Created')
df_community_sizes = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'Communities']).size().unstack(fill_value=0)
existing_selected_cols = [col for col in selected_communities if col in df_community_sizes.columns]
if existing_selected_cols:
st.area_chart(df_community_sizes[existing_selected_cols])
else:
st.warning("No data available for the selected communities.")
with tab_lists:
st.markdown("### Hashtag Communities (by importance)")
st.dataframe(sorted_community_hashtags)
st.markdown("### Top Edge Pairs (by weight)")
st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(100))
elif selected_menu_option == "Clustering":
st.markdown("## Clustering of Posts")
st.markdown("This section allows you to group posts based on the similarity of their text or image content. Use the sidebar to configure the clustering process, then click 'Run Clustering' to see the results.")
# --- Sidebar Configuration (no changes here) ---
st.sidebar.markdown("# Clustering Options")
st.sidebar.markdown("### Data & Algorithm")
type_embeddings = st.sidebar.selectbox("Cluster based on:", ["Image", "Text"])
clustering_algo = st.sidebar.selectbox("Clustering Algorithm:", ["MiniBatchKMeans", "HDBSCAN", "KMeans"])
st.sidebar.info(f"**Tip:** `MiniBatchKMeans` is the fastest for a quick overview.")
st.sidebar.markdown("### Algorithm Settings")
if clustering_algo in ["KMeans", "MiniBatchKMeans"]:
n_clusters = st.sidebar.slider("Number of Clusters (k)", 2, 50, 5, key="n_clusters_slider")
if clustering_algo == "MiniBatchKMeans":
batch_size = st.sidebar.slider("Batch Size", 32, 1024, 256, 32, help="Number of samples to use in each mini-batch.")
max_iter = st.sidebar.slider("Max Iterations", 50, 500, 100, 50, help="Maximum number of iterations.")
else:
batch_size, max_iter = None, None
min_cluster_size, min_samples = None, None
elif clustering_algo == "HDBSCAN":
min_cluster_size = st.sidebar.slider("Minimum Cluster Size", 2, 200, 15, help="Smallest size for a group to be a cluster.")
min_samples = st.sidebar.slider("Minimum Samples", 1, 50, 5, help="Larger values lead to more points being declared as noise.")
n_clusters, batch_size, max_iter = None, None, None
st.sidebar.markdown("### Dimensionality Reduction")
dim_reduction = st.sidebar.selectbox("Reduction Method:", ["PCA", "UMAP"])
st.sidebar.info(f"**Tip:** `PCA` is much faster than `UMAP`.")
if dim_reduction == "UMAP":
n_components = st.sidebar.slider("Number of Components", 2, 80, 50, help="Dimensions to reduce to before clustering.")
n_neighbors = st.sidebar.slider("Number of Neighbors", 2, 50, 15, help="Controls UMAP's balance of local/global structure.")
min_dist = st.sidebar.slider("Minimum Distance", 0.0, 1.0, 0.0, help="Controls how tightly UMAP packs points.")
else:
n_components = st.sidebar.slider("Number of Components", 2, 80, 2)
n_neighbors, min_dist = None, None
# --- Main Page Content ---
# 1. Add a button to trigger the expensive computation
if st.button("Run Clustering", type="primary"):
with st.spinner("Clustering embeddings... This may take a moment."):
if type_embeddings == "Text":
embeddings = dataset['text_embs']
else: # Image
embeddings = dataset['img_embs']
# Call the expensive function here
labels, reduced_embeddings = cluster_embeddings(
embeddings,
clustering_algo=clustering_algo,
dim_reduction=dim_reduction,
n_clusters=n_clusters,
min_cluster_size=min_cluster_size,
n_components=n_components,
n_neighbors=n_neighbors,
min_dist=min_dist,
min_samples=min_samples,
batch_size=batch_size,
max_iter=max_iter
)
# 2. Store the results in session state
st.session_state['cluster_results'] = {
"labels": labels,
"reduced_embeddings": reduced_embeddings,
"type_embeddings": type_embeddings,
"clustering_algo": clustering_algo,
"dim_reduction": dim_reduction
}
st.rerun() # Rerun to display results immediately after calculation
# 3. Only show results if they exist in session state
if 'cluster_results' in st.session_state:
# Unpack results from session state
results = st.session_state['cluster_results']
labels = results['labels']
reduced_embeddings = results['reduced_embeddings']
num_found_clusters = len(set(labels) - {-1})
st.success(f"Clustering complete! Found **{num_found_clusters}** clusters using **{results['clustering_algo']}** on **{results['type_embeddings']}** embeddings with **{results['dim_reduction']}** reduction.")
df_clustered = df.copy()
df_clustered['cluster'] = labels
# 4. Use tabs to organize the output
tab1, tab2, tab3 = st.tabs(["📊 Results Table", "📈 2D Visualization", "🕒 Time Series Analysis"])
with tab1:
st.markdown("### Clustered Data")
st.dataframe(
data=filter_dataframe(df_clustered),
column_config={
"image": st.column_config.ImageColumn("Image", help="Instagram image"),
"URL": st.column_config.LinkColumn("Link", help="Instagram link", width="small")
},
hide_index=True,
use_container_width=True
)
st.download_button(
"Download Clustered Data as CSV",
df_clustered.to_csv(index=False).encode('utf-8'),
f'clustered_data_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv',
"text/csv",
key='download-cluster-csv'
)
with tab2:
st.markdown("### Cluster Visualization")
if reduced_embeddings.shape[1] > 2:
with st.spinner("Reducing dimensions for 2D visualization..."):
vis_reducer = umap.UMAP(n_components=2, random_state=42)
vis_embeddings = vis_reducer.fit_transform(reduced_embeddings)
else:
vis_embeddings = reduced_embeddings
df_plot_bokeh = pd.DataFrame(vis_embeddings, columns=('x', 'y'))
df_plot_bokeh['description_clean'] = df_clustered['description_clean']
df_plot_bokeh['image_url'] = df_clustered['image']
df_plot_bokeh['cluster'] = labels
unique_labels = sorted(list(set(labels)))
color_dict = {label: rgb2hex(cc.glasbey_hv[i % len(cc.glasbey_hv)]) for i, label in enumerate(unique_labels)}
df_plot_bokeh['color'] = df_plot_bokeh['cluster'].map(color_dict)
source = ColumnDataSource(data=df_plot_bokeh)
TOOLTIPS = """
<div style="width: 200px; padding: 5px; background-color: #f0f0f0; border-radius: 5px; font-family: sans-serif; border: 1px solid #cccccc;">
<div>
<img src="@image_url"
height="150"
width="150"
style="display: block; margin: auto;"
border="0">
</img>
</div>
<hr style="border: 1px solid #aaaaaa; margin: 8px 0;">
<div style="text-align: left; padding: 0 5px;">
<span style="font-size: 12px; font-weight: bold;">Cluster: @cluster</span><br>
<span style="font-size: 11px; word-wrap: break-word;">@description_clean</span>
</div>
</div>
"""
p = figure(width=800, height=800, tooltips=TOOLTIPS, title="2D Visualization of Post Clusters")
p.circle('x', 'y', size=10, source=source, color='color', legend_field='cluster', line_color=None, alpha=0.8)
p.legend.title = 'Cluster'
p.legend.location = "top_left"
st.bokeh_chart(p, use_container_width=True)
with tab3:
st.markdown("### Cluster Analysis Over Time")
# Define the dictionary before using it.
resample_dict = {
'Day': 'D',
'Week': 'W',
'Month': 'M',
'Quarter': 'Q',
'Year': 'Y'
}
variable = st.selectbox('Select Variable for Time Series:', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'], key="cluster_ts_var")
resample_time = st.selectbox('Resample Time By:', list(resample_dict.keys()), index=2, key="cluster_ts_resample")
df_ts = df_clustered.copy()
df_ts['Post Created'] = pd.to_datetime(df_ts['Post Created'])
df_ts = df_ts.set_index('Post Created')
df_ts = df_ts[df_ts['cluster'] != -1] # Exclude noise points
if not df_ts.empty:
# Use the dictionary to get the correct frequency string ('D', 'W', 'M', etc.)
df_plot = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'cluster'])[variable].sum().unstack(fill_value=0)
st.line_chart(df_plot)
else:
st.warning("No data available for plotting (all points may have been classified as noise).")
elif selected_menu_option == "Stats":
st.markdown("### Time Series Analysis")
# Dropdown to select variables
variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'])
# Dropdown to select time resampling
resample_dict = {
'Day': 'D',
'Three Days': '3D',
'Week': 'W',
'Two Weeks': '2W',
'Month': 'M',
'Quarter': 'Q',
'Year': 'Y'
}
# Dropdown to select time resampling
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))
df_filtered = df.set_index('Post Created')
# Slider for date range selection
min_date = df_filtered.index.min().date()
max_date = df_filtered.index.max().date()
date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))
# Filter dataframe based on selected date range
df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])]
# Create a separate DataFrame for resampling and plotting
df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum()
st.line_chart(df_resampled)
st.markdown("### Correlation Analysis")
# Dropdown to select variables for scatter plot
options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']
scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options)
# options.remove(scatter_variable_1) # remove the chosen option from the list
scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options)
# Plot scatter chart
st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")
# Plot scatter chart
scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2) #, trendline='ols', trendline_color_override='red')
st.plotly_chart(scatter_fig)
# calculate correlation for scatter_variable_1 with scatter_variable_2
corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2])
if corr > 0.7:
st.write(f"The correlation coefficient is {corr}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
elif corr > 0.3:
st.write(f"The correlation coefficient is {corr}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
elif corr > -0.3:
st.write(f"The correlation coefficient is {corr}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.")
elif corr > -0.7:
st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
else:
st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")