import pandas as pd
import streamlit as st
import datasets
import plotly.express as px
from transformers import AutoProcessor, AutoModel
from PIL import Image
import os
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
import subprocess
from tempfile import NamedTemporaryFile
from itertools import combinations
import networkx as nx
import plotly.graph_objects as go
import colorcet as cc
from matplotlib.colors import rgb2hex
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
import hdbscan
import umap
import numpy as np
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from datetime import datetime
import re

# st.set_page_config(layout="wide")

model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
token_ = st.secrets["token"]
def load_model(model_name):
    """
    Load the model and processor
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return processor, model
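# Note (assumption, not applied here): the underscore-prefixed parameters used further
# down (_dataset, _processor, _model, _G) follow Streamlit's convention for arguments
# that should not be hashed by st.cache_data / st.cache_resource, which suggests these
# loaders were meant to be cached. A minimal sketch of that intent would be:
#
#   @st.cache_resource(show_spinner=True)
#   def load_model(model_name): ...
#
#   @st.cache_data(show_spinner=True)
#   def load_dataframe(_dataset): ...
#
# Without caching, the model and dataset are reloaded on every Streamlit rerun.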
def load_dataset():
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_)
    dataset.add_faiss_index(column="text_embs")
    dataset.add_faiss_index(column="img_embs")
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Like and View Counts Disabled', 'Link', 'Download URL', 'Views'])
    return dataset
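# The two FAISS indexes built above are what dataset.get_nearest_examples() queries in
# the semantic-search helpers below: 'text_embs' holds pre-computed CLIP text embeddings
# (presumably of the post descriptions) and 'img_embs' holds CLIP image embeddings, so a
# single normalized query vector of the same dimensionality can be searched against either.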
def load_dataframe(_dataset):
    dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas()
    # Extract hashtags with regex and convert to set
    dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
    dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)
    # Create a cleaned description column up-front
    dataframe['description_clean'] = dataframe['Description'].apply(clean_and_truncate_text)
    # Reorder columns to keep the new column next to the original
    front_cols = ['Post Created', 'image', 'Description', 'description_clean', 'Image Text', 'Account', 'User Name']
    dataframe = dataframe[front_cols + [col for col in dataframe.columns if col not in front_cols]]
    return dataframe
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    modify = st.checkbox("Add filters")

    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # Treat columns with < 10 unique values as categorical
            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    # na=False keeps rows with missing values from breaking the boolean mask
                    df = df[df[column].str.contains(user_text_input, na=False)]

    return df
def get_image_embs(_processor, _model, uploaded_file):
    """
    Get image embeddings

    Parameters:
        _processor (transformers.AutoProcessor): Processor for the model
        _model (transformers.AutoModel): Model to use for embeddings
        uploaded_file (str or file-like): Path to (or open handle of) the image file

    Returns:
        img_emb (np.array): Normalized image embedding
    """
    # Load the image from the local path / file handle
    image = Image.open(uploaded_file)
    # Process the image
    inputs = _processor(images=image, return_tensors="pt")
    # Forward pass to compute the image features
    outputs = _model.get_image_features(**inputs)
    # Normalize the image embeddings to unit length
    img_embs = outputs / outputs.norm(dim=-1, keepdim=True)
    # Detach and convert to a NumPy vector
    img_emb = img_embs.squeeze(0).detach().cpu().numpy()
    return img_emb
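# Because the stored and query embeddings are both L2-normalized, nearest-neighbour
# search gives the same ranking whether the FAISS index measures Euclidean distance or
# inner product: for unit vectors ||a - b||^2 = 2 - 2 * (a . b), so the normalization
# above effectively turns the FAISS lookup into a cosine-similarity search.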
def get_text_embs(_processor, _model, text):
    """
    Get text embeddings

    Parameters:
        _processor (transformers.AutoProcessor): Processor for the model
        _model (transformers.AutoModel): Model to use for embeddings
        text (str): Text to encode

    Returns:
        txt_emb (np.array): Normalized text embedding
    """
    # Process the text with truncation
    inputs = _processor(
        text=text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=77  # CLIP's maximum sequence length
    )
    # Forward pass to compute the text features
    outputs = _model.get_text_features(**inputs)
    # Normalize the text embeddings to unit length
    text_embs = outputs / outputs.norm(dim=-1, keepdim=True)
    # Detach and convert to a NumPy vector
    txt_emb = text_embs.squeeze(0).detach().cpu().numpy()
    return txt_emb
def postprocess_results(scores, samples):
    """
    Postprocess nearest-neighbour results into a display-ready dataframe

    Parameters:
        scores (np.array): Distances returned by the FAISS index (lower = closer)
        samples (datasets.Dataset): Retrieved samples

    Returns:
        pd.DataFrame: Samples with a 0-100 'score' column, embedding columns dropped
    """
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["score"] = scores
    # Invert and min-max scale the distances so that 100 = closest match, 0 = furthest
    samples_df["score"] = (1 - (samples_df["score"] - samples_df["score"].min()) / (
        samples_df["score"].max() - samples_df["score"].min())) * 100
    samples_df["score"] = samples_df["score"].astype(int)
    samples_df.reset_index(inplace=True, drop=True)
    front_cols = ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']
    samples_df = samples_df[front_cols + [col for col in samples_df.columns if col not in front_cols]]
    return samples_df.drop(columns=['text_embs', 'img_embs'])
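# Worked example of the rescaling above (hypothetical distances): raw scores
# [0.4, 0.7, 1.0] min-max scale to [0.0, 0.5, 1.0], invert to [1.0, 0.5, 0.0], and
# multiply by 100, giving display scores [100, 50, 0]. If every returned distance is
# identical the denominator is zero and the result is NaN; that (rare) tie case is not
# handled here.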
def text_to_text(text, k=5):
    """
    Text to text search

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts whose text embeddings are closest to the query text
    """
    text_emb = get_text_embs(processor, model, text)
    scores, samples = dataset.get_nearest_examples('text_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


def image_to_text(image, k=5):
    """
    Image to text search

    Parameters:
        image (file-like): Temporary file whose .name points to the query image
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts whose text embeddings are closest to the query image
    """
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


def text_to_image(text, k=5):
    """
    Text to image search

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts whose image embeddings are closest to the query text
    """
    text_emb = get_text_embs(processor, model, text)
    scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


def image_to_image(image, k=5):
    """
    Image to image search

    Parameters:
        image (file-like): Temporary file whose .name points to the query image
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts whose image embeddings are closest to the query image
    """
    img_emb = get_image_embs(processor, model, image.name)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)
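# The four helpers above differ only in which embedding is used as the query and which
# FAISS index is searched: text_to_text (text query -> 'text_embs'), text_to_image
# (text query -> 'img_embs'), image_to_image (image query -> 'img_embs') and
# image_to_text (image query -> 'text_embs'). Cross-modal search works because CLIP
# maps text and images into the same embedding space.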
def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
    """
    Computes the backbone of the input graph using the disparity filter algorithm.

    The algorithm is proposed in:
    M. A. Serrano, M. Boguna, and A. Vespignani,
    "Extracting the Multiscale Backbone of Complex Weighted Networks",
    PNAS, 106(16), pp 6483--6488 (2009).
    DOI: 10.1073/pnas.0808904106
    Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ

    Parameters
    ----------
    g : NetworkX graph
        The input graph.
    weight : str, optional (default='weight')
        The name of the edge attribute to use as weight.
    alpha : float, optional (default=0.05)
        The statistical significance level for the disparity filter (p-value).

    Returns
    -------
    backbone_graph : NetworkX graph
        The backbone graph.
    """
    # Create an empty graph for the backbone
    backbone_graph = nx.Graph()
    # Iterate over all nodes in the input graph
    for node in g:
        # Get the degree of the node (number of edges connected to the node)
        k_n = len(g[node])
        # Only proceed if the node has more than one connection
        if k_n > 1:
            # Calculate the sum of weights of edges connected to the node
            sum_w = sum(g[node][neighbor][weight] for neighbor in g[node])
            # Iterate over all neighbors of the node
            for neighbor in g[node]:
                # Get the weight of the edge between the node and its neighbor
                edge_weight = g[node][neighbor][weight]
                # Calculate the proportion of the node's total weight carried by this edge
                pij = float(edge_weight) / sum_w
                # Disparity filter test: keep the edge if it is statistically significant
                if (1 - pij) ** (k_n - 1) < alpha:
                    backbone_graph.add_edge(node, neighbor, weight=edge_weight)
    # Return the backbone graph
    return backbone_graph
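# Under the disparity filter's null model, a node of degree k spreads its strength
# uniformly at random over its k edges, and the probability that an edge carries at
# least a fraction p_ij of that strength is (1 - p_ij)**(k - 1). The test above keeps an
# edge when that probability drops below alpha. Worked example (hypothetical numbers): a
# hashtag with k = 4 neighbours and one edge carrying p_ij = 0.7 of its weight gives
# (1 - 0.7)**3 = 0.027 < 0.05, so the edge is kept; an edge with p_ij = 0.25 gives
# (0.75)**3 ≈ 0.42 and is dropped. Each undirected edge is tested once from each
# endpoint and kept if either test passes, which is how this simple implementation behaves.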
@st.cache_data(show_spinner=True)
def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a unique color to each community in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    dict
        A dictionary mapping each community to a unique color.
    """
    glasbey_colors = cc.glasbey_hv
    communities_ = set(nx.get_node_attributes(G, attr).values())
    return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}
@st.cache_data(show_spinner=True)
def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    list
        A list of strings containing the hover text for each node.
    """
    return [f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}" for node, adjacencies in G.adjacency()]
@st.cache_data(show_spinner=True)
def calculate_node_sizes(G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    list
        A list of node sizes.
    """
    degrees = dict(G.degree())
    max_degree = max(deg for node, deg in degrees.items())
    return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]
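# Node sizes are scaled linearly with degree into the range [10, 30]: the node with the
# maximum degree gets 10 + 20 * 1 = 30, and a node with no edges would get 10. If every
# node had degree 0, max_degree would be 0 and the division would fail; that edge case
# is left unhandled here.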
def plot_graph(_G: nx.Graph, layout_name: str = "spring", community_names_lookup: dict = None):
    """
    Plots a network graph with communities and a legend, using a choice of pure-Python layouts.

    Parameters
    ----------
    _G : nx.Graph
        The input graph with a 'community' attribute on each node.
    layout_name : str, optional
        The name of the NetworkX layout algorithm to use.
    community_names_lookup : dict, optional
        A dictionary mapping community key (e.g., 'Community 1') to a display name.
    """
    # --- Select the layout algorithm ---
    if layout_name == "kamada_kawai":
        # Aesthetically pleasing, can be slow on large graphs.
        pos = nx.kamada_kawai_layout(_G, dim=3)
    elif layout_name == "circular":
        # Fast, simple circle. It's 2D, so we add a Z-coordinate.
        pos_2d = nx.circular_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    elif layout_name == "spectral":
        # Good for showing clusters. Also 2D, so we add a Z-coordinate.
        pos_2d = nx.spectral_layout(_G)
        pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
    else:  # Default to "spring"
        # The standard physics-based layout.
        pos = nx.spring_layout(_G, dim=3, k=0.15, iterations=50, seed=779)

    # --- Generate colors and the edge/node traces ---
    communities = sorted(list(set(nx.get_node_attributes(_G, 'community').values())))
    community_colors = {comm: color for comm, color in zip(communities, cc.glasbey_hv)}

    edge_x, edge_y, edge_z = [], [], []
    for edge in _G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    data = [edge_trace]
    for comm_idx in communities:
        comm_key = f'Community {comm_idx + 1}'
        comm_name = community_names_lookup.get(comm_key, comm_key)
        node_x, node_y, node_z, node_text = [], [], [], []
        for node in _G.nodes():
            if _G.nodes[node]['community'] == comm_idx:
                x, y, z = pos[node]
                node_x.append(x)
                node_y.append(y)
                node_z.append(z)
                node_text.append(f"Hashtag: #{node}<br>Community: {comm_name}")
        node_trace = go.Scatter3d(
            x=node_x, y=node_y, z=node_z,
            mode='markers',
            name=comm_name,
            marker=dict(
                symbol='circle',
                size=7,
                color=rgb2hex(community_colors[comm_idx]),
                line=dict(color='rgb(50,50,50)', width=0.5)
            ),
            text=node_text,
            hoverinfo='text'
        )
        data.append(node_trace)

    # --- Figure layout ---
    layout = go.Layout(
        title="3D Hashtag Network Graph",
        showlegend=True,
        legend=dict(title="Communities", x=1.05, y=0.5),
        width=1000,
        height=800,
        margin=dict(l=0, r=0, b=0, t=40),
        scene=dict(
            xaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''),
            yaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''),
            zaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='')
        )
    )

    fig = go.Figure(data=data, layout=layout)
    return fig
def clean_and_truncate_text(text, max_length=30):
    """
    Removes hashtags and truncates text to a specified length.

    Args:
        text (str): The input string to clean.
        max_length (int): The maximum length of the output string.

    Returns:
        str: The cleaned and truncated string.
    """
    if not isinstance(text, str):
        return ""  # Return empty string for non-string inputs
    # Use regex to remove hashtags (words starting with #)
    no_hashtags = re.sub(r'#\w+\s*', '', text).strip()
    # Truncate the string if it's too long
    if len(no_hashtags) > max_length:
        return no_hashtags[:max_length] + '...'
    else:
        return no_hashtags
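# Example (hypothetical input): clean_and_truncate_text("Nunca mais #ditadura #memoria
# sera esquecida pelos que viveram aquele tempo") strips the two hashtags and, because
# the remaining text is longer than 30 characters, truncates it to its first 30
# characters plus '...'. The result is used only for compact hover labels.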
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA',
                       # KMeans & MiniBatchKMeans params
                       n_clusters=5, batch_size=256, max_iter=100,
                       # HDBSCAN params
                       min_cluster_size=5, min_samples=5,
                       # Reducer params
                       n_components=2, n_neighbors=15, min_dist=0.0, random_state=42):
    """Performs dimensionality reduction and clustering on a set of embeddings.

    This function chains two steps: first, it reduces the dimensionality of the
    input embeddings using either PCA or UMAP. Second, it applies a clustering
    algorithm (KMeans, MiniBatchKMeans, or HDBSCAN) to the reduced-dimensional
    data to assign a cluster label to each embedding.

    Args:
        embeddings (list or np.ndarray): A list or array of high-dimensional
            embedding vectors. Each element should be a 1D NumPy array.
        clustering_algo (str, optional): The clustering algorithm to use.
            Options are 'KMeans', 'MiniBatchKMeans', or 'HDBSCAN'.
            Defaults to 'KMeans'.
        dim_reduction (str, optional): The dimensionality reduction method to use.
            Options are 'PCA' or 'UMAP'. Defaults to 'PCA'.
        n_clusters (int, optional): The number of clusters to form. Used by
            KMeans and MiniBatchKMeans. Defaults to 5.
        batch_size (int, optional): The size of mini-batches for MiniBatchKMeans.
            Defaults to 256.
        max_iter (int, optional): The maximum number of iterations for
            MiniBatchKMeans. Defaults to 100.
        min_cluster_size (int, optional): The minimum number of samples in a
            group for it to be considered a cluster. Used by HDBSCAN.
            Defaults to 5.
        min_samples (int, optional): The number of samples in a neighborhood for
            a point to be considered a core point. Used by HDBSCAN.
            Defaults to 5.
        n_components (int, optional): The number of dimensions to reduce to.
            Used by PCA and UMAP. Defaults to 2.
        n_neighbors (int, optional): The number of neighbors to consider for
            manifold approximation. Used by UMAP. Defaults to 15.
        min_dist (float, optional): The effective minimum distance between
            embedded points. Used by UMAP. Defaults to 0.0.
        random_state (int, optional): The seed used by the random number
            generator for reproducibility. Defaults to 42.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: An array of cluster labels assigned to each embedding.
            - np.ndarray: The reduced-dimensional representation of the embeddings.

    Raises:
        ValueError: If an invalid `clustering_algo` or `dim_reduction` method
            is specified.
    """
    # Stack embeddings into a single NumPy array
    data_array = np.stack(embeddings)

    # --- 1. Dimensionality Reduction ---
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    elif dim_reduction == 'UMAP':
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)
    else:
        raise ValueError('Invalid dimensionality reduction method')
    reduced_embeddings = reducer.fit_transform(data_array)

    # --- 2. Clustering ---
    if clustering_algo == 'MiniBatchKMeans':
        # Use the specific parameters for MiniBatchKMeans
        clusterer = MiniBatchKMeans(
            n_clusters=n_clusters,
            random_state=random_state,
            batch_size=batch_size,
            max_iter=max_iter,
            n_init='auto'  # Recommended setting
        )
    elif clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto')
    elif clustering_algo == 'HDBSCAN':
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    else:
        raise ValueError('Invalid clustering algorithm')

    labels = clusterer.fit_predict(reduced_embeddings)
    return labels, reduced_embeddings
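# Minimal usage sketch outside the app (assumptions: `vectors` is any list of equal-length
# 1D arrays; parameter values are illustrative, not tuned):
#
#   vectors = [np.random.rand(512) for _ in range(1000)]
#   labels, coords = cluster_embeddings(vectors, clustering_algo='HDBSCAN',
#                                       dim_reduction='UMAP', n_components=5,
#                                       min_cluster_size=15)
#
# With HDBSCAN, points that fit no cluster receive the label -1 ("noise"); the clustering
# UI below excludes those points in its time-series tab.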
st.title("#ditaduranuncamais Data Explorer")


def check_password():
    """Returns `True` if the user is authenticated, `False` otherwise."""
    # If the user is already authenticated, just return True.
    # This is the most important part: we don't render the password form again.
    if st.session_state.get("password_correct", False):
        return True

    # This part of the code will only run if the user is not yet authenticated.
    def password_entered():
        """Checks whether the password entered is correct."""
        if st.session_state.get("password") == st.secrets.get("password"):
            st.session_state["password_correct"] = True
            # Don't store the password in session state.
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    # Show the password input form.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    # Show an error message if the last attempt was incorrect.
    # The 'in' check prevents the error from showing on the first load.
    if "password_correct" in st.session_state and not st.session_state.password_correct:
        st.error("😕 Password incorrect")
    # Return False to stop the main app from running.
    return False


if not check_password():
    st.stop()
# Load the dataset, dataframe and model once the user is authenticated
dataset = load_dataset()
df = load_dataframe(dataset)
processor, model = load_model(model_name)
# image_model = load_img_model()
# text_model = load_txt_model()

menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]
st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)
if selected_menu_option == "Data exploration": | |
st.dataframe( | |
data=filter_dataframe(df), | |
# use_container_width=True, | |
column_config={ | |
"image": st.column_config.ImageColumn( | |
"Image", help="Instagram image" | |
), | |
"URL": st.column_config.LinkColumn( | |
"Link", help="Instagram link", width="small" | |
) | |
}, | |
hide_index=True, | |
) | |
elif selected_menu_option == "Semantic search": | |
tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"] | |
selected_tab = st.sidebar.radio("Select a search type", tabs) | |
if selected_tab == "Text to Text": | |
st.markdown('## Text to text search') | |
text_to_text_input = st.text_input("Enter text") | |
text_to_text_k_top = st.slider("Number of results", 1, 500, 20) | |
if st.button("Search"): | |
if not text_to_text_input: | |
st.warning("Please enter text") | |
else: | |
st.dataframe( | |
data=text_to_text(text_to_text_input, text_to_text_k_top), | |
column_config={ | |
"image": st.column_config.ImageColumn( | |
"Image", help="Instagram image" | |
), | |
"URL": st.column_config.LinkColumn( | |
"Link", help="Instagram link", width="small" | |
) | |
}, | |
hide_index=True, | |
) | |
elif selected_tab == "Text to Image": | |
st.markdown('## Text to image search') | |
text_to_image_input = st.text_input("Enter text") | |
text_to_image_k_top = st.slider("Number of results", 1, 500, 20) | |
if st.button("Search"): | |
if not text_to_image_input: | |
st.warning("Please enter some text") | |
else: | |
st.dataframe( | |
data=text_to_image(text_to_image_input, text_to_image_k_top), | |
column_config={ | |
"image": st.column_config.ImageColumn( | |
"Image", help="Instagram image" | |
), | |
"URL": st.column_config.LinkColumn( | |
"Link", help="Instagram link", width="small" | |
) | |
}, | |
hide_index=True, | |
) | |
elif selected_tab == "Image to Image": | |
st.markdown('## Image to image search') | |
image_to_image_k_top = st.slider("Number of results", 1, 500, 20) | |
image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"]) | |
temp_file = NamedTemporaryFile(delete=False) | |
if st.button("Search"): | |
if not image_to_image_input: | |
st.warning("Please upload an image") | |
else: | |
temp_file.write(image_to_image_input.getvalue()) | |
st.dataframe( | |
data=image_to_image(temp_file, image_to_image_k_top), | |
column_config={ | |
"image": st.column_config.ImageColumn( | |
"Image", help="Instagram image" | |
), | |
"URL": st.column_config.LinkColumn( | |
"Link", help="Instagram link", width="small" | |
) | |
}, | |
hide_index=True, | |
) | |
elif selected_tab == "Image to Text": | |
st.markdown('## Image to text search') | |
image_to_text_k_top = st.slider("Number of results", 1, 500, 20) | |
image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"]) | |
temp_file = NamedTemporaryFile(delete=False) | |
if st.button("Search"): | |
if not image_to_text_input: | |
st.warning("Please upload an image") | |
else: | |
temp_file.write(image_to_text_input.getvalue()) | |
st.dataframe( | |
data=image_to_text(temp_file, image_to_text_k_top), | |
column_config={ | |
"image": st.column_config.ImageColumn( | |
"Image", help="Instagram image" | |
), | |
"URL": st.column_config.LinkColumn( | |
"Link", help="Instagram link", width="small" | |
) | |
}, | |
hide_index=True, | |
) | |
elif selected_menu_option == "Hashtags": | |
st.markdown("### Hashtag Co-occurrence Analysis") | |
st.markdown("This section creates a network of hashtags based on how often they are used together. Use the sidebar to configure the analysis, then click the button to generate the network and identify communities.") | |
# --- Sidebar Configuration (no changes) --- | |
if 'dfx' not in st.session_state: | |
st.session_state.dfx = df.copy() | |
all_hashtags = sorted(list(set(item for sublist in st.session_state.dfx['Hashtags'] for item in sublist))) | |
st.sidebar.markdown('## Hashtag Network Options') | |
hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags) | |
col1, col2 = st.sidebar.columns(2) | |
if col1.button("Remove hashtags"): | |
st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove]) | |
if 'hashtag_results' in st.session_state: | |
del st.session_state.hashtag_results | |
st.rerun() | |
if col2.button("Reset Hashtags"): | |
st.session_state.dfx = df.copy() | |
if 'hashtag_results' in st.session_state: | |
del st.session_state.hashtag_results | |
st.rerun() | |
weight_option = st.sidebar.radio( | |
'Select weight definition', | |
('Number of users that use the hashtag pairs', 'Total number of occurrences') | |
) | |
# --- Main Page Content --- | |
if st.button("Generate Hashtag Network", type="primary"): | |
with st.spinner("Building graph, filtering edges, and detecting communities..."): | |
# (Calculation code remains the same as before...) | |
hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)] | |
hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name']) | |
if weight_option == 'Number of users that use the hashtag pairs': | |
edge_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index() | |
else: | |
edge_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name') | |
edge_df = edge_df.rename(columns={'User Name': 'weight'}) | |
edge_df[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_df['hashtag_pair'].tolist(), index=edge_df.index) | |
edge_list = edge_df[['hashtag1', 'hashtag2', 'weight']] | |
G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight') | |
G_backbone = disparity_filter(G, weight='weight', alpha=0.05) | |
communities = list(nx.community.louvain_communities(G_backbone, weight='weight', seed=1234)) | |
communities.sort(key=len, reverse=True) | |
for i, community in enumerate(communities): | |
for node in community: | |
G_backbone.nodes[node]['community'] = i | |
sorted_community_hashtags = pd.DataFrame([ | |
[h for h, _ in sorted(((h, G.degree(h, weight='weight')) for h in com), key=lambda x: x[1], reverse=True)] | |
for com in communities | |
]).T | |
sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))] | |
# Initialize the community names dataframe and store it in session state | |
df_community_names = pd.DataFrame( | |
sorted_community_hashtags.columns, | |
columns=['community_names'], | |
index=sorted_community_hashtags.columns | |
) | |
st.session_state.community_names_df = df_community_names | |
st.session_state.hashtag_results = { | |
"G_backbone": G_backbone, | |
"communities": communities, | |
"sorted_community_hashtags": sorted_community_hashtags, | |
"edge_list": edge_list | |
} | |
st.rerun() | |
# --- Display Results Section --- | |
if 'hashtag_results' in st.session_state: | |
results = st.session_state.hashtag_results | |
G_backbone = results['G_backbone'] | |
communities = results['communities'] | |
sorted_community_hashtags = results['sorted_community_hashtags'] | |
edge_list = results['edge_list'] | |
st.success(f"Network generated! Found **{len(communities)}** communities from **{len(G_backbone.nodes)}** hashtags and **{len(G_backbone.edges)}** connections.") | |
# Define the tabs with the editor in its own tab | |
tab_graph, tab_editor, tab_timeline, tab_lists = st.tabs([ | |
"📊 Network Graph", | |
"📝 Edit Community Names", | |
"🕒 Community Timelines", | |
"📋 Community & Edge Lists" | |
]) | |
with tab_graph: | |
st.markdown("### Hashtag Network Graph") | |
st.markdown("Nodes represent hashtags, colored by community. The legend uses the names from the 'Edit Community Names' tab.") | |
# Re-introduce the layout selector with safe, pure-Python options | |
layout_options = { | |
"Spring": "spring", | |
"Kamada-Kawai": "kamada_kawai", | |
"Circular": "circular", | |
"Spectral": "spectral" | |
} | |
selected_layout_name = st.selectbox( | |
"Graph Layout Algorithm", | |
options=layout_options.keys() | |
) | |
# Get the actual function name string | |
layout_alg_str = layout_options[selected_layout_name] | |
# Retrieve edited names from session state | |
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict() | |
# Call the plot function with the chosen layout | |
fig = plot_graph( | |
_G=G_backbone, | |
layout_name=layout_alg_str, | |
community_names_lookup=community_names_lookup | |
) | |
st.plotly_chart(fig, use_container_width=True) | |
with tab_editor: | |
st.markdown("### Edit Community Names") | |
st.markdown("Change the default community names in the table below. The new names will automatically update the graph legend and the timeline chart.") | |
# The data editor modifies the dataframe in session_state | |
edited_df = st.data_editor( | |
st.session_state.community_names_df, | |
use_container_width=True, | |
num_rows="dynamic" # Allows for adding/removing if needed, though less likely here | |
) | |
# Persist any changes back to session state | |
st.session_state.community_names_df = edited_df | |
st.download_button( | |
label="Download Community Names as CSV", | |
data=edited_df.to_csv().encode("utf-8"), | |
file_name="community_names.csv", | |
mime="text/csv", | |
) | |
with tab_timeline: | |
st.markdown("### Community Size Over Time") | |
# Retrieve the latest names from session state for the multiselect options | |
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict() | |
selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=list(community_names_lookup.values())) | |
resample_dict = {'Day': 'D', 'Week': 'W', 'Month': 'M', 'Quarter': 'Q', 'Year': 'Y'} | |
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4) | |
community_dict = {node: community_names_lookup.get(f'Community {i+1}') for i, comm_set in enumerate(communities) for node in comm_set} | |
df_communities = st.session_state.dfx.copy() | |
df_communities['Communities'] = df_communities['Hashtags'].apply(lambda tags: list(set(community_dict.get(tag) for tag in tags if tag in community_dict))) | |
df_communities = df_communities.explode('Communities').dropna(subset=['Communities']) | |
df_ts = df_communities.set_index('Post Created') | |
df_community_sizes = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'Communities']).size().unstack(fill_value=0) | |
existing_selected_cols = [col for col in selected_communities if col in df_community_sizes.columns] | |
if existing_selected_cols: | |
st.area_chart(df_community_sizes[existing_selected_cols]) | |
else: | |
st.warning("No data available for the selected communities.") | |
with tab_lists: | |
st.markdown("### Hashtag Communities (by importance)") | |
st.dataframe(sorted_community_hashtags) | |
st.markdown("### Top Edge Pairs (by weight)") | |
st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(100)) | |
elif selected_menu_option == "Clustering": | |
st.markdown("## Clustering of Posts") | |
st.markdown("This section allows you to group posts based on the similarity of their text or image content. Use the sidebar to configure the clustering process, then click 'Run Clustering' to see the results.") | |
# --- Sidebar Configuration (no changes here) --- | |
st.sidebar.markdown("# Clustering Options") | |
st.sidebar.markdown("### Data & Algorithm") | |
type_embeddings = st.sidebar.selectbox("Cluster based on:", ["Image", "Text"]) | |
clustering_algo = st.sidebar.selectbox("Clustering Algorithm:", ["MiniBatchKMeans", "HDBSCAN", "KMeans"]) | |
st.sidebar.info(f"**Tip:** `MiniBatchKMeans` is the fastest for a quick overview.") | |
st.sidebar.markdown("### Algorithm Settings") | |
if clustering_algo in ["KMeans", "MiniBatchKMeans"]: | |
n_clusters = st.sidebar.slider("Number of Clusters (k)", 2, 50, 5, key="n_clusters_slider") | |
if clustering_algo == "MiniBatchKMeans": | |
batch_size = st.sidebar.slider("Batch Size", 32, 1024, 256, 32, help="Number of samples to use in each mini-batch.") | |
max_iter = st.sidebar.slider("Max Iterations", 50, 500, 100, 50, help="Maximum number of iterations.") | |
else: | |
batch_size, max_iter = None, None | |
min_cluster_size, min_samples = None, None | |
elif clustering_algo == "HDBSCAN": | |
min_cluster_size = st.sidebar.slider("Minimum Cluster Size", 2, 200, 15, help="Smallest size for a group to be a cluster.") | |
min_samples = st.sidebar.slider("Minimum Samples", 1, 50, 5, help="Larger values lead to more points being declared as noise.") | |
n_clusters, batch_size, max_iter = None, None, None | |
st.sidebar.markdown("### Dimensionality Reduction") | |
dim_reduction = st.sidebar.selectbox("Reduction Method:", ["PCA", "UMAP"]) | |
st.sidebar.info(f"**Tip:** `PCA` is much faster than `UMAP`.") | |
if dim_reduction == "UMAP": | |
n_components = st.sidebar.slider("Number of Components", 2, 80, 50, help="Dimensions to reduce to before clustering.") | |
n_neighbors = st.sidebar.slider("Number of Neighbors", 2, 50, 15, help="Controls UMAP's balance of local/global structure.") | |
min_dist = st.sidebar.slider("Minimum Distance", 0.0, 1.0, 0.0, help="Controls how tightly UMAP packs points.") | |
else: | |
n_components = st.sidebar.slider("Number of Components", 2, 80, 2) | |
n_neighbors, min_dist = None, None | |
# --- Main Page Content --- | |
# 1. Add a button to trigger the expensive computation | |
if st.button("Run Clustering", type="primary"): | |
with st.spinner("Clustering embeddings... This may take a moment."): | |
if type_embeddings == "Text": | |
embeddings = dataset['text_embs'] | |
else: # Image | |
embeddings = dataset['img_embs'] | |
# Call the expensive function here | |
labels, reduced_embeddings = cluster_embeddings( | |
embeddings, | |
clustering_algo=clustering_algo, | |
dim_reduction=dim_reduction, | |
n_clusters=n_clusters, | |
min_cluster_size=min_cluster_size, | |
n_components=n_components, | |
n_neighbors=n_neighbors, | |
min_dist=min_dist, | |
min_samples=min_samples, | |
batch_size=batch_size, | |
max_iter=max_iter | |
) | |
# 2. Store the results in session state | |
st.session_state['cluster_results'] = { | |
"labels": labels, | |
"reduced_embeddings": reduced_embeddings, | |
"type_embeddings": type_embeddings, | |
"clustering_algo": clustering_algo, | |
"dim_reduction": dim_reduction | |
} | |
st.rerun() # Rerun to display results immediately after calculation | |
# 3. Only show results if they exist in session state | |
if 'cluster_results' in st.session_state: | |
# Unpack results from session state | |
results = st.session_state['cluster_results'] | |
labels = results['labels'] | |
reduced_embeddings = results['reduced_embeddings'] | |
num_found_clusters = len(set(labels) - {-1}) | |
st.success(f"Clustering complete! Found **{num_found_clusters}** clusters using **{results['clustering_algo']}** on **{results['type_embeddings']}** embeddings with **{results['dim_reduction']}** reduction.") | |
df_clustered = df.copy() | |
df_clustered['cluster'] = labels | |
# 4. Use tabs to organize the output | |
tab1, tab2, tab3 = st.tabs(["📊 Results Table", "📈 2D Visualization", "🕒 Time Series Analysis"]) | |
with tab1: | |
st.markdown("### Clustered Data") | |
st.dataframe( | |
data=filter_dataframe(df_clustered), | |
column_config={ | |
"image": st.column_config.ImageColumn("Image", help="Instagram image"), | |
"URL": st.column_config.LinkColumn("Link", help="Instagram link", width="small") | |
}, | |
hide_index=True, | |
use_container_width=True | |
) | |
st.download_button( | |
"Download Clustered Data as CSV", | |
df_clustered.to_csv(index=False).encode('utf-8'), | |
f'clustered_data_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv', | |
"text/csv", | |
key='download-cluster-csv' | |
) | |
with tab2: | |
st.markdown("### Cluster Visualization") | |
if reduced_embeddings.shape[1] > 2: | |
with st.spinner("Reducing dimensions for 2D visualization..."): | |
vis_reducer = umap.UMAP(n_components=2, random_state=42) | |
vis_embeddings = vis_reducer.fit_transform(reduced_embeddings) | |
else: | |
vis_embeddings = reduced_embeddings | |
df_plot_bokeh = pd.DataFrame(vis_embeddings, columns=('x', 'y')) | |
df_plot_bokeh['description_clean'] = df_clustered['description_clean'] | |
df_plot_bokeh['image_url'] = df_clustered['image'] | |
df_plot_bokeh['cluster'] = labels | |
unique_labels = sorted(list(set(labels))) | |
color_dict = {label: rgb2hex(cc.glasbey_hv[i % len(cc.glasbey_hv)]) for i, label in enumerate(unique_labels)} | |
df_plot_bokeh['color'] = df_plot_bokeh['cluster'].map(color_dict) | |
source = ColumnDataSource(data=df_plot_bokeh) | |
TOOLTIPS = """ | |
<div style="width: 200px; padding: 5px; background-color: #f0f0f0; border-radius: 5px; font-family: sans-serif; border: 1px solid #cccccc;"> | |
<div> | |
<img src="@image_url" | |
height="150" | |
width="150" | |
style="display: block; margin: auto;" | |
border="0"> | |
</img> | |
</div> | |
<hr style="border: 1px solid #aaaaaa; margin: 8px 0;"> | |
<div style="text-align: left; padding: 0 5px;"> | |
<span style="font-size: 12px; font-weight: bold;">Cluster: @cluster</span><br> | |
<span style="font-size: 11px; word-wrap: break-word;">@description_clean</span> | |
</div> | |
</div> | |
""" | |
p = figure(width=800, height=800, tooltips=TOOLTIPS, title="2D Visualization of Post Clusters") | |
p.circle('x', 'y', size=10, source=source, color='color', legend_field='cluster', line_color=None, alpha=0.8) | |
p.legend.title = 'Cluster' | |
p.legend.location = "top_left" | |
st.bokeh_chart(p, use_container_width=True) | |
with tab3: | |
st.markdown("### Cluster Analysis Over Time") | |
# Define the dictionary before using it. | |
resample_dict = { | |
'Day': 'D', | |
'Week': 'W', | |
'Month': 'M', | |
'Quarter': 'Q', | |
'Year': 'Y' | |
} | |
variable = st.selectbox('Select Variable for Time Series:', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'], key="cluster_ts_var") | |
resample_time = st.selectbox('Resample Time By:', list(resample_dict.keys()), index=2, key="cluster_ts_resample") | |
df_ts = df_clustered.copy() | |
df_ts['Post Created'] = pd.to_datetime(df_ts['Post Created']) | |
df_ts = df_ts.set_index('Post Created') | |
df_ts = df_ts[df_ts['cluster'] != -1] # Exclude noise points | |
if not df_ts.empty: | |
# Use the dictionary to get the correct frequency string ('D', 'W', 'M', etc.) | |
df_plot = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'cluster'])[variable].sum().unstack(fill_value=0) | |
st.line_chart(df_plot) | |
else: | |
st.warning("No data available for plotting (all points may have been classified as noise).") | |
elif selected_menu_option == "Stats": | |
st.markdown("### Time Series Analysis") | |
# Dropdown to select variables | |
variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']) | |
# Dropdown to select time resampling | |
resample_dict = { | |
'Day': 'D', | |
'Three Days': '3D', | |
'Week': 'W', | |
'Two Weeks': '2W', | |
'Month': 'M', | |
'Quarter': 'Q', | |
'Year': 'Y' | |
} | |
# Dropdown to select time resampling | |
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys())) | |
df_filtered = df.set_index('Post Created') | |
# Slider for date range selection | |
min_date = df_filtered.index.min().date() | |
max_date = df_filtered.index.max().date() | |
date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date)) | |
# Filter dataframe based on selected date range | |
df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])] | |
# Create a separate DataFrame for resampling and plotting | |
df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum() | |
st.line_chart(df_resampled) | |
st.markdown("### Correlation Analysis") | |
# Dropdown to select variables for scatter plot | |
options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'] | |
scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options) | |
# options.remove(scatter_variable_1) # remove the chosen option from the list | |
scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options) | |
# Plot scatter chart | |
st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}") | |
# Plot scatter chart | |
scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2) #, trendline='ols', trendline_color_override='red') | |
st.plotly_chart(scatter_fig) | |
# calculate correlation for scatter_variable_1 with scatter_variable_2 | |
corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2]) | |
if corr > 0.7: | |
st.write(f"The correlation coefficient is {corr}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.") | |
elif corr > 0.3: | |
st.write(f"The correlation coefficient is {corr}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.") | |
elif corr > -0.3: | |
st.write(f"The correlation coefficient is {corr}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.") | |
elif corr > -0.7: | |
st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.") | |
else: | |
st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.") |