|
import pandas as pd |
|
import streamlit as st |
|
import datasets |
|
import plotly.express as px |
|
from transformers import AutoProcessor, AutoModel |
|
from PIL import Image |
|
import os |
|
from pandas.api.types import ( |
|
is_categorical_dtype, |
|
is_datetime64_any_dtype, |
|
is_numeric_dtype, |
|
is_object_dtype, |
|
) |
|
import subprocess |
|
from tempfile import NamedTemporaryFile |
|
from itertools import combinations |
|
import networkx as nx |
|
import plotly.graph_objects as go |
|
import colorcet as cc |
|
from matplotlib.colors import rgb2hex |
|
from sklearn.cluster import KMeans, MiniBatchKMeans |
|
from sklearn.decomposition import PCA |
|
import hdbscan |
|
import umap |
|
import numpy as np |
|
from bokeh.plotting import figure |
|
from bokeh.models import ColumnDataSource |
|
from datetime import datetime |
|
import re |
|
|
|
|
|
|
|
model_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" |
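# CLIP checkpoint used for both text and image embeddings; the processor/model pair
# loaded below maps captions and images into the same embedding space, which is what
# makes the cross-modal (text <-> image) search further down possible.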
|
|
|
token_ = st.secrets["token"] |
|
|
|
@st.cache_resource(show_spinner=True) |
|
def load_model(model_name): |
|
""" |
|
Load the model and processor |
|
""" |
|
processor = AutoProcessor.from_pretrained(model_name) |
|
model = AutoModel.from_pretrained(model_name) |
|
return processor, model |
|
|
|
@st.cache_data(show_spinner=True) |
|
def load_dataset(): |
|
dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', token=token_) |
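    # FAISS indexes over the precomputed embedding columns enable the
    # dataset.get_nearest_examples() calls used by the semantic search functions below.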
|
dataset.add_faiss_index(column="text_embs") |
|
dataset.add_faiss_index(column="img_embs") |
|
dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time','Like and View Counts Disabled','Link','Download URL','Views']) |
|
return dataset |
|
|
|
@st.cache_data(show_spinner=False) |
|
def load_dataframe(_dataset): |
|
dataframe = _dataset.remove_columns(['text_embs', 'img_embs']).to_pandas() |
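    # Hashtags are harvested from both the caption ('Description') and the OCR'd
    # image text, lower-cased, and de-duplicated per post via set().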
|
|
|
|
|
dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1) |
|
dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set) |
|
|
|
|
|
dataframe['description_clean'] = dataframe['Description'].apply(clean_and_truncate_text) |
|
|
|
|
|
dataframe = dataframe[['Post Created', 'image', 'Description', 'description_clean', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'description_clean', 'Image Text', 'Account', 'User Name']]] |
|
return dataframe |
|
|
|
|
|
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: |
|
""" |
|
Adds a UI on top of a dataframe to let viewers filter columns |
|
Args: |
|
df (pd.DataFrame): Original dataframe |
|
Returns: |
|
pd.DataFrame: Filtered dataframe |
|
""" |
|
modify = st.checkbox("Add filters") |
|
|
|
if not modify: |
|
return df |
|
|
|
df = df.copy() |
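    # Tentatively parse object columns as datetimes and strip timezone info so the
    # date-range widget below can compare naive timestamps.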
|
|
|
|
|
for col in df.columns: |
|
if is_object_dtype(df[col]): |
|
try: |
|
df[col] = pd.to_datetime(df[col]) |
|
except Exception: |
|
pass |
|
|
|
if is_datetime64_any_dtype(df[col]): |
|
df[col] = df[col].dt.tz_localize(None) |
|
|
|
modification_container = st.container() |
|
|
|
with modification_container: |
|
to_filter_columns = st.multiselect("Filter dataframe on", df.columns) |
|
for column in to_filter_columns: |
|
left, right = st.columns((1, 20)) |
|
left.write("↳") |
|
|
|
if is_categorical_dtype(df[column]) or df[column].nunique() < 10: |
|
user_cat_input = right.multiselect( |
|
f"Values for {column}", |
|
df[column].unique(), |
|
default=list(df[column].unique()), |
|
) |
|
df = df[df[column].isin(user_cat_input)] |
|
elif is_numeric_dtype(df[column]): |
|
_min = float(df[column].min()) |
|
_max = float(df[column].max()) |
|
                step = (_max - _min) / 100 or 1.0  # fall back to 1.0 so a constant column does not produce a zero step
|
user_num_input = right.slider( |
|
f"Values for {column}", |
|
_min, |
|
_max, |
|
(_min, _max), |
|
step=step, |
|
) |
|
df = df[df[column].between(*user_num_input)] |
|
elif is_datetime64_any_dtype(df[column]): |
|
user_date_input = right.date_input( |
|
f"Values for {column}", |
|
value=( |
|
df[column].min(), |
|
df[column].max(), |
|
), |
|
) |
|
if len(user_date_input) == 2: |
|
user_date_input = tuple(map(pd.to_datetime, user_date_input)) |
|
start_date, end_date = user_date_input |
|
df = df.loc[df[column].between(start_date, end_date)] |
|
else: |
|
user_text_input = right.text_input( |
|
f"Substring or regex in {column}", |
|
) |
|
if user_text_input: |
|
df = df[df[column].str.contains(user_text_input)] |
|
|
|
return df |
|
|
|
@st.cache_data |
|
def get_image_embs(_processor, _model, uploaded_file): |
|
""" |
|
Get image embeddings |
|
Parameters: |
|
processor (transformers.AutoProcessor): Processor for the model |
|
model (transformers.AutoModel): Model to use for embeddings |
|
    uploaded_file (str or file-like): Path to, or file object for, the uploaded image
|
Returns: |
|
img_emb (np.array): Image embeddings |
|
""" |
|
|
|
image = Image.open(uploaded_file) |
|
|
|
|
|
inputs = _processor(images=image, return_tensors="pt") |
|
|
|
|
|
outputs = _model.get_image_features(**inputs) |
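    # L2-normalise so nearest-neighbour search behaves like cosine similarity
    # (this assumes the embeddings stored in the dataset were normalised the same way).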
|
|
|
|
|
img_embs = outputs / outputs.norm(dim=-1, keepdim=True) |
|
|
|
|
|
img_emb = img_embs.squeeze(0).detach().cpu().numpy() |
|
|
|
return img_emb |
|
|
|
@st.cache_data(show_spinner=False) |
|
def get_text_embs(_processor, _model, text): |
|
""" |
|
Get text embeddings |
|
Parameters: |
|
processor (transformers.AutoProcessor): Processor for the model |
|
model (transformers.AutoModel): Model to use for embeddings |
|
text (str): Text to encode |
|
Returns: |
|
text_emb (np.array): Text embeddings |
|
""" |
|
|
|
inputs = _processor( |
|
text=text, |
|
return_tensors="pt", |
|
padding="max_length", |
|
truncation=True, |
|
max_length=77 |
|
) |
|
|
|
|
|
outputs = _model.get_text_features(**inputs) |
|
|
|
|
|
text_embs = outputs / outputs.norm(dim=-1, keepdim=True) |
|
|
|
|
|
txt_emb = text_embs.squeeze(0).detach().cpu().numpy() |
|
|
|
return txt_emb |
|
|
|
@st.cache_data |
|
def postprocess_results(scores, samples): |
|
""" |
|
Postprocess results to tuple of labels and scores |
|
Parameters: |
|
    scores (np.ndarray): FAISS distances for the retrieved samples
    samples (dict): Retrieved examples, as returned by Dataset.get_nearest_examples
    Returns:
    samples_df (pd.DataFrame): Matching posts with a normalised similarity score (0-100)
|
""" |
|
samples_df = pd.DataFrame.from_dict(samples) |
|
samples_df["score"] = scores |
|
samples_df["score"] = (1 - (samples_df["score"] - samples_df["score"].min()) / ( |
|
samples_df["score"].max() - samples_df["score"].min())) * 100 |
|
samples_df["score"] = samples_df["score"].astype(int) |
|
samples_df.reset_index(inplace=True, drop=True) |
|
samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]] |
|
return samples_df.drop(columns=['text_embs', 'img_embs']) |
|
|
|
@st.cache_data |
|
def text_to_text(text, k=5): |
|
""" |
|
Text to text |
|
Parameters: |
|
text (str): Input text |
|
k (int): Number of top results to return |
|
Returns: |
|
    results (pd.DataFrame): Top-k matching posts with similarity scores
|
""" |
|
text_emb = get_text_embs(processor, model, text) |
|
scores, samples = dataset.get_nearest_examples('text_embs', text_emb, k=k) |
|
return postprocess_results(scores, samples) |
|
|
|
@st.cache_data |
|
def image_to_text(image, k=5): |
|
""" |
|
Image to text |
|
Parameters: |
|
    image (file): Temporary file object pointing to the uploaded image
    k (int): Number of top results to return
    Returns:
    results (pd.DataFrame): Top-k matching posts with similarity scores
|
""" |
|
img_emb = get_image_embs(processor, model, image.name) |
|
scores, samples = dataset.get_nearest_examples('text_embs', img_emb, k=k) |
|
return postprocess_results(scores, samples) |
|
|
|
@st.cache_data |
|
def text_to_image(text, k=5): |
|
""" |
|
Text to image |
|
Parameters: |
|
text (str): Input text |
|
k (int): Number of top results to return |
|
Returns: |
|
    results (pd.DataFrame): Top-k matching posts with similarity scores
|
""" |
|
text_emb = get_text_embs(processor, model, text) |
|
scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k) |
|
return postprocess_results(scores, samples) |
|
|
|
@st.cache_data |
|
def image_to_image(image, k=5): |
|
""" |
|
Image to image |
|
Parameters: |
|
    image (file): Temporary file object pointing to the uploaded image
    k (int): Number of top results to return
    Returns:
    results (pd.DataFrame): Top-k matching posts with similarity scores
|
""" |
|
img_emb = get_image_embs(processor, model, image.name) |
|
scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k) |
|
return postprocess_results(scores, samples) |
|
|
|
def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph: |
|
""" |
|
Computes the backbone of the input graph using the disparity filter algorithm. |
|
The algorithm is proposed in: |
|
M. A. Serrano, M. Boguna, and A. Vespignani, |
|
"Extracting the Multiscale Backbone of Complex Weighted Networks", |
|
PNAS, 106(16), pp 6483--6488 (2009). |
|
DOI: 10.1073/pnas.0808904106 |
|
Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ |
|
Parameters |
|
---------- |
|
g : NetworkX graph |
|
The input graph. |
|
weight : str, optional (default='weight') |
|
The name of the edge attribute to use as weight. |
|
alpha : float, optional (default=0.05) |
|
The statistical significance level for the disparity filter (p-value). |
|
Returns |
|
------- |
|
backbone_graph : NetworkX graph |
|
The backbone graph. |
|
""" |
|
|
|
backbone_graph = nx.Graph() |
|
|
|
|
|
for node in g: |
|
|
|
k_n = len(g[node]) |
|
|
|
|
|
if k_n > 1: |
|
|
|
sum_w = sum(g[node][neighbor][weight] for neighbor in g[node]) |
|
|
|
|
|
for neighbor in g[node]: |
|
|
|
edge_weight = g[node][neighbor][weight] |
|
|
|
|
|
pij = float(edge_weight) / sum_w |
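                # Disparity-filter significance test: under the null model that a node's
                # strength is spread uniformly over its k_n edges, the p-value of an edge
                # carrying fraction pij of that strength is (1 - pij)**(k_n - 1); keep
                # edges whose p-value falls below alpha.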
|
|
|
|
|
if (1 - pij) ** (k_n - 1) < alpha: |
|
backbone_graph.add_edge(node, neighbor, weight=edge_weight) |
|
|
|
|
|
return backbone_graph |
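# Illustrative use only (hypothetical weights): keep only the statistically
# significant co-occurrence edges of a small weighted graph.
#   g = nx.Graph()
#   g.add_weighted_edges_from([("a", "b", 100.0), ("a", "c", 1.0), ("b", "c", 1.0)])
#   backbone = disparity_filter(g, alpha=0.05)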
|
|
|
@st.cache_data(show_spinner=True)
def assign_community_colors(_G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a unique color to each community in the input graph.
    Parameters
    ----------
    _G : nx.Graph
        The input graph (leading underscore keeps Streamlit from trying to hash it).
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').
    Returns
    -------
    dict
        A dictionary mapping each community to a unique color.
    """
    glasbey_colors = cc.glasbey_hv
    communities_ = set(nx.get_node_attributes(_G, attr).values())
    return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}
|
|
|
@st.cache_data(show_spinner=True)
def generate_hover_text(_G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.
    Parameters
    ----------
    _G : nx.Graph
        The input graph (leading underscore keeps Streamlit from trying to hash it).
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').
    Returns
    -------
    list
        A list of strings containing the hover text for each node.
    """
    return [f"Node: {str(node)}<br>Community: {_G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}" for node, adjacencies in _G.adjacency()]
|
|
|
@st.cache_data(show_spinner=True)
def calculate_node_sizes(_G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.
    Parameters
    ----------
    _G : nx.Graph
        The input graph (leading underscore keeps Streamlit from trying to hash it).
    Returns
    -------
    list
        A list of node sizes.
    """
    degrees = dict(_G.degree())
    max_degree = max(degrees.values(), default=1) or 1  # guard against empty graphs or isolated nodes
    return [10 + 20 * (degrees[node] / max_degree) for node in _G.nodes()]
|
|
|
@st.cache_data(show_spinner=True) |
|
def plot_graph(_G: nx.Graph, layout_name: str = "spring", community_names_lookup: dict = None): |
|
""" |
|
Plots a network graph with communities and a legend, using a choice of pure-Python layouts. |
|
Parameters |
|
---------- |
|
_G : nx.Graph |
|
The input graph with a 'community' attribute on each node. |
|
layout_name : str, optional |
|
The name of the NetworkX layout algorithm to use. |
|
community_names_lookup : dict, optional |
|
A dictionary mapping community key (e.g., 'Community 1') to a display name. |
|
""" |
|
|
|
if layout_name == "kamada_kawai": |
|
|
|
pos = nx.kamada_kawai_layout(_G, dim=3) |
|
elif layout_name == "circular": |
|
|
|
pos_2d = nx.circular_layout(_G) |
|
pos = {node: (*coords, 0) for node, coords in pos_2d.items()} |
|
elif layout_name == "spectral": |
|
|
|
pos_2d = nx.spectral_layout(_G) |
|
pos = {node: (*coords, 0) for node, coords in pos_2d.items()} |
|
else: |
|
|
|
pos = nx.spring_layout(_G, dim=3, k=0.15, iterations=50, seed=779) |
|
|
|
|
|
communities = sorted(list(set(nx.get_node_attributes(_G, 'community').values()))) |
|
community_colors = {comm: color for comm, color in zip(communities, cc.glasbey_hv)} |
|
|
|
edge_x, edge_y, edge_z = [], [], [] |
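    # Plotly draws each Scatter3d trace as one continuous line; the None entries
    # appended after every edge break the line so edges are rendered separately.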
|
for edge in _G.edges(): |
|
x0, y0, z0 = pos[edge[0]] |
|
x1, y1, z1 = pos[edge[1]] |
|
edge_x.extend([x0, x1, None]) |
|
edge_y.extend([y0, y1, None]) |
|
edge_z.extend([z0, z1, None]) |
|
|
|
edge_trace = go.Scatter3d( |
|
x=edge_x, y=edge_y, z=edge_z, |
|
line=dict(width=0.5, color='#888'), |
|
hoverinfo='none', |
|
mode='lines') |
|
|
|
data = [edge_trace] |
|
for comm_idx in communities: |
|
comm_key = f'Community {comm_idx + 1}' |
|
comm_name = community_names_lookup.get(comm_key, comm_key) |
|
|
|
node_x, node_y, node_z, node_text = [], [], [], [] |
|
for node in _G.nodes(): |
|
if _G.nodes[node]['community'] == comm_idx: |
|
x, y, z = pos[node] |
|
node_x.append(x) |
|
node_y.append(y) |
|
node_z.append(z) |
|
node_text.append(f"Hashtag: #{node}<br>Community: {comm_name}") |
|
|
|
node_trace = go.Scatter3d( |
|
x=node_x, y=node_y, z=node_z, |
|
mode='markers', |
|
name=comm_name, |
|
marker=dict( |
|
symbol='circle', |
|
size=7, |
|
color=rgb2hex(community_colors[comm_idx]), |
|
line=dict(color='rgb(50,50,50)', width=0.5) |
|
), |
|
text=node_text, |
|
hoverinfo='text' |
|
) |
|
data.append(node_trace) |
|
|
|
|
|
layout = go.Layout( |
|
title="3D Hashtag Network Graph", |
|
showlegend=True, |
|
legend=dict(title="Communities", x=1.05, y=0.5), |
|
width=1000, |
|
height=800, |
|
margin=dict(l=0, r=0, b=0, t=40), |
|
scene=dict( |
|
xaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''), |
|
yaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title=''), |
|
zaxis=dict(showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title='') |
|
) |
|
) |
|
|
|
fig = go.Figure(data=data, layout=layout) |
|
return fig |
|
|
|
def clean_and_truncate_text(text, max_length=30): |
|
""" |
|
Removes hashtags and truncates text to a specified length. |
|
|
|
Args: |
|
text (str): The input string to clean. |
|
max_length (int): The maximum length of the output string. |
|
|
|
Returns: |
|
str: The cleaned and truncated string. |
|
""" |
|
if not isinstance(text, str): |
|
return "" |
|
|
|
|
|
no_hashtags = re.sub(r'#\w+\s*', '', text).strip() |
|
|
|
|
|
if len(no_hashtags) > max_length: |
|
return no_hashtags[:max_length] + '...' |
|
else: |
|
return no_hashtags |
|
|
|
@st.cache_data(show_spinner=True) |
|
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA', |
|
|
|
n_clusters=5, batch_size=256, max_iter=100, |
|
|
|
min_cluster_size=5, min_samples=5, |
|
|
|
n_components=2, n_neighbors=15, min_dist=0.0, random_state=42): |
|
"""Performs dimensionality reduction and clustering on a set of embeddings. |
|
|
|
This function chains two steps: first, it reduces the dimensionality of the |
|
input embeddings using either PCA or UMAP. Second, it applies a clustering |
|
algorithm (KMeans, MiniBatchKMeans, or HDBSCAN) to the reduced-dimensional |
|
data to assign a cluster label to each embedding. |
|
|
|
Args: |
|
embeddings (list or np.ndarray): A list or array of high-dimensional |
|
embedding vectors. Each element should be a 1D NumPy array. |
|
clustering_algo (str, optional): The clustering algorithm to use. |
|
Options are 'KMeans', 'MiniBatchKMeans', or 'HDBSCAN'. |
|
Defaults to 'KMeans'. |
|
dim_reduction (str, optional): The dimensionality reduction method to use. |
|
Options are 'PCA' or 'UMAP'. Defaults to 'PCA'. |
|
n_clusters (int, optional): The number of clusters to form. Used by |
|
KMeans and MiniBatchKMeans. Defaults to 5. |
|
batch_size (int, optional): The size of mini-batches for MiniBatchKMeans. |
|
Defaults to 256. |
|
max_iter (int, optional): The maximum number of iterations for |
|
MiniBatchKMeans. Defaults to 100. |
|
min_cluster_size (int, optional): The minimum number of samples in a |
|
group for it to be considered a cluster. Used by HDBSCAN. |
|
Defaults to 5. |
|
min_samples (int, optional): The number of samples in a neighborhood for |
|
a point to be considered a core point. Used by HDBSCAN. |
|
Defaults to 5. |
|
n_components (int, optional): The number of dimensions to reduce to. |
|
Used by PCA and UMAP. Defaults to 2. |
|
n_neighbors (int, optional): The number of neighbors to consider for |
|
manifold approximation. Used by UMAP. Defaults to 15. |
|
min_dist (float, optional): The effective minimum distance between |
|
embedded points. Used by UMAP. Defaults to 0.0. |
|
random_state (int, optional): The seed used by the random number |
|
generator for reproducibility. Defaults to 42. |
|
|
|
Returns: |
|
tuple: A tuple containing: |
|
- np.ndarray: An array of cluster labels assigned to each embedding. |
|
- np.ndarray: The reduced-dimensional representation of the embeddings. |
|
|
|
Raises: |
|
ValueError: If an invalid `clustering_algo` or `dim_reduction` method |
|
is specified. |
|
""" |
|
|
|
data_array = np.stack(embeddings) |
|
|
|
|
|
if dim_reduction == 'PCA': |
|
reducer = PCA(n_components=n_components, random_state=random_state) |
|
elif dim_reduction == 'UMAP': |
|
reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state) |
|
else: |
|
raise ValueError('Invalid dimensionality reduction method') |
|
|
|
reduced_embeddings = reducer.fit_transform(data_array) |
|
|
|
|
|
if clustering_algo == 'MiniBatchKMeans': |
|
|
|
clusterer = MiniBatchKMeans( |
|
n_clusters=n_clusters, |
|
random_state=random_state, |
|
batch_size=batch_size, |
|
max_iter=max_iter, |
|
n_init='auto' |
|
) |
|
elif clustering_algo == 'KMeans': |
|
clusterer = KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto') |
|
elif clustering_algo == 'HDBSCAN': |
|
clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples) |
|
else: |
|
raise ValueError('Invalid clustering algorithm') |
|
|
|
labels = clusterer.fit_predict(reduced_embeddings) |
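    # Note: HDBSCAN labels points it cannot assign to any cluster as -1 (noise);
    # the KMeans variants always assign every point to a cluster.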
|
|
|
return labels, reduced_embeddings |
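# Illustrative call (hypothetical parameters): cluster the image embeddings into
# 10 groups after reducing them to 50 dimensions with UMAP.
#   labels, coords = cluster_embeddings(dataset['img_embs'], clustering_algo='KMeans',
#                                       dim_reduction='UMAP', n_clusters=10, n_components=50)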
|
|
|
st.title("#ditaduranuncamais Data Explorer") |
|
|
|
def check_password(): |
|
"""Returns `True` if user is authenticated, `False` otherwise.""" |
|
|
|
|
|
|
|
if st.session_state.get("password_correct", False): |
|
return True |
|
|
|
|
|
def password_entered(): |
|
"""Checks whether the password entered is correct.""" |
|
if st.session_state.get("password") == st.secrets.get("password"): |
|
st.session_state["password_correct"] = True |
|
|
|
del st.session_state["password"] |
|
else: |
|
st.session_state["password_correct"] = False |
|
|
|
|
|
st.text_input( |
|
"Password", type="password", on_change=password_entered, key="password" |
|
) |
|
|
|
|
|
|
|
if "password_correct" in st.session_state and not st.session_state.password_correct: |
|
st.error("😕 Password incorrect") |
|
|
|
|
|
return False |
|
|
|
if not check_password(): |
|
st.stop() |
|
|
|
|
|
|
|
dataset = load_dataset() |
|
df = load_dataframe(dataset) |
|
processor, model = load_model(model_name) |
|
|
|
|
|
|
|
menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"] |
|
|
|
st.sidebar.markdown('# Menu') |
|
selected_menu_option = st.sidebar.radio("Select a page", menu_options) |
|
|
|
if selected_menu_option == "Data exploration": |
|
st.dataframe( |
|
data=filter_dataframe(df), |
|
|
|
column_config={ |
|
"image": st.column_config.ImageColumn( |
|
"Image", help="Instagram image" |
|
), |
|
"URL": st.column_config.LinkColumn( |
|
"Link", help="Instagram link", width="small" |
|
) |
|
}, |
|
hide_index=True, |
|
) |
|
|
|
elif selected_menu_option == "Semantic search": |
|
tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"] |
|
selected_tab = st.sidebar.radio("Select a search type", tabs) |
|
|
|
if selected_tab == "Text to Text": |
|
st.markdown('## Text to text search') |
|
text_to_text_input = st.text_input("Enter text") |
|
text_to_text_k_top = st.slider("Number of results", 1, 500, 20) |
|
if st.button("Search"): |
|
if not text_to_text_input: |
|
st.warning("Please enter text") |
|
else: |
|
st.dataframe( |
|
data=text_to_text(text_to_text_input, text_to_text_k_top), |
|
column_config={ |
|
"image": st.column_config.ImageColumn( |
|
"Image", help="Instagram image" |
|
), |
|
"URL": st.column_config.LinkColumn( |
|
"Link", help="Instagram link", width="small" |
|
) |
|
}, |
|
hide_index=True, |
|
) |
|
|
|
elif selected_tab == "Text to Image": |
|
st.markdown('## Text to image search') |
|
text_to_image_input = st.text_input("Enter text") |
|
text_to_image_k_top = st.slider("Number of results", 1, 500, 20) |
|
if st.button("Search"): |
|
if not text_to_image_input: |
|
st.warning("Please enter some text") |
|
else: |
|
st.dataframe( |
|
data=text_to_image(text_to_image_input, text_to_image_k_top), |
|
column_config={ |
|
"image": st.column_config.ImageColumn( |
|
"Image", help="Instagram image" |
|
), |
|
"URL": st.column_config.LinkColumn( |
|
"Link", help="Instagram link", width="small" |
|
) |
|
}, |
|
hide_index=True, |
|
) |
|
|
|
elif selected_tab == "Image to Image": |
|
st.markdown('## Image to image search') |
|
image_to_image_k_top = st.slider("Number of results", 1, 500, 20) |
|
image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"]) |
|
temp_file = NamedTemporaryFile(delete=False) |
|
if st.button("Search"): |
|
if not image_to_image_input: |
|
st.warning("Please upload an image") |
|
else: |
|
temp_file.write(image_to_image_input.getvalue()) |
|
|
|
st.dataframe( |
|
data=image_to_image(temp_file, image_to_image_k_top), |
|
column_config={ |
|
"image": st.column_config.ImageColumn( |
|
"Image", help="Instagram image" |
|
), |
|
"URL": st.column_config.LinkColumn( |
|
"Link", help="Instagram link", width="small" |
|
) |
|
}, |
|
hide_index=True, |
|
) |
|
|
|
elif selected_tab == "Image to Text": |
|
st.markdown('## Image to text search') |
|
image_to_text_k_top = st.slider("Number of results", 1, 500, 20) |
|
image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"]) |
|
temp_file = NamedTemporaryFile(delete=False) |
|
if st.button("Search"): |
|
if not image_to_text_input: |
|
st.warning("Please upload an image") |
|
else: |
|
temp_file.write(image_to_text_input.getvalue()) |
|
st.dataframe( |
|
data=image_to_text(temp_file, image_to_text_k_top), |
|
column_config={ |
|
"image": st.column_config.ImageColumn( |
|
"Image", help="Instagram image" |
|
), |
|
"URL": st.column_config.LinkColumn( |
|
"Link", help="Instagram link", width="small" |
|
) |
|
}, |
|
hide_index=True, |
|
) |
|
elif selected_menu_option == "Hashtags": |
|
st.markdown("### Hashtag Co-occurrence Analysis") |
|
st.markdown("This section creates a network of hashtags based on how often they are used together. Use the sidebar to configure the analysis, then click the button to generate the network and identify communities.") |
|
|
|
|
|
if 'dfx' not in st.session_state: |
|
st.session_state.dfx = df.copy() |
|
all_hashtags = sorted(list(set(item for sublist in st.session_state.dfx['Hashtags'] for item in sublist))) |
|
st.sidebar.markdown('## Hashtag Network Options') |
|
hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags) |
|
col1, col2 = st.sidebar.columns(2) |
|
if col1.button("Remove hashtags"): |
|
st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove]) |
|
if 'hashtag_results' in st.session_state: |
|
del st.session_state.hashtag_results |
|
st.rerun() |
|
if col2.button("Reset Hashtags"): |
|
st.session_state.dfx = df.copy() |
|
if 'hashtag_results' in st.session_state: |
|
del st.session_state.hashtag_results |
|
st.rerun() |
|
weight_option = st.sidebar.radio( |
|
'Select weight definition', |
|
('Number of users that use the hashtag pairs', 'Total number of occurrences') |
|
) |
|
|
|
|
|
if st.button("Generate Hashtag Network", type="primary"): |
|
with st.spinner("Building graph, filtering edges, and detecting communities..."): |
|
|
|
hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)] |
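            # Each post contributes every unordered pair of its hashtags, tagged with the
            # posting account, so edge weights can be either the number of unique users
            # or the raw co-occurrence count, depending on the option selected above.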
|
hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name']) |
|
if weight_option == 'Number of users that use the hashtag pairs': |
|
edge_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index() |
|
else: |
|
edge_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name') |
|
edge_df = edge_df.rename(columns={'User Name': 'weight'}) |
|
edge_df[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_df['hashtag_pair'].tolist(), index=edge_df.index) |
|
edge_list = edge_df[['hashtag1', 'hashtag2', 'weight']] |
|
G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight') |
|
G_backbone = disparity_filter(G, weight='weight', alpha=0.05) |
|
communities = list(nx.community.louvain_communities(G_backbone, weight='weight', seed=1234)) |
|
communities.sort(key=len, reverse=True) |
|
for i, community in enumerate(communities): |
|
for node in community: |
|
G_backbone.nodes[node]['community'] = i |
|
sorted_community_hashtags = pd.DataFrame([ |
|
[h for h, _ in sorted(((h, G.degree(h, weight='weight')) for h in com), key=lambda x: x[1], reverse=True)] |
|
for com in communities |
|
]).T |
|
sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))] |
|
|
|
|
|
df_community_names = pd.DataFrame( |
|
sorted_community_hashtags.columns, |
|
columns=['community_names'], |
|
index=sorted_community_hashtags.columns |
|
) |
|
st.session_state.community_names_df = df_community_names |
|
|
|
st.session_state.hashtag_results = { |
|
"G_backbone": G_backbone, |
|
"communities": communities, |
|
"sorted_community_hashtags": sorted_community_hashtags, |
|
"edge_list": edge_list |
|
} |
|
st.rerun() |
|
|
|
|
|
if 'hashtag_results' in st.session_state: |
|
results = st.session_state.hashtag_results |
|
G_backbone = results['G_backbone'] |
|
communities = results['communities'] |
|
sorted_community_hashtags = results['sorted_community_hashtags'] |
|
edge_list = results['edge_list'] |
|
|
|
st.success(f"Network generated! Found **{len(communities)}** communities from **{len(G_backbone.nodes)}** hashtags and **{len(G_backbone.edges)}** connections.") |
|
|
|
|
|
tab_graph, tab_editor, tab_timeline, tab_lists = st.tabs([ |
|
"📊 Network Graph", |
|
"📝 Edit Community Names", |
|
"🕒 Community Timelines", |
|
"📋 Community & Edge Lists" |
|
]) |
|
|
|
with tab_graph: |
|
st.markdown("### Hashtag Network Graph") |
|
st.markdown("Nodes represent hashtags, colored by community. The legend uses the names from the 'Edit Community Names' tab.") |
|
|
|
|
|
layout_options = { |
|
"Spring": "spring", |
|
"Kamada-Kawai": "kamada_kawai", |
|
"Circular": "circular", |
|
"Spectral": "spectral" |
|
} |
|
selected_layout_name = st.selectbox( |
|
"Graph Layout Algorithm", |
|
options=layout_options.keys() |
|
) |
|
|
|
|
|
layout_alg_str = layout_options[selected_layout_name] |
|
|
|
|
|
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict() |
|
|
|
|
|
fig = plot_graph( |
|
_G=G_backbone, |
|
layout_name=layout_alg_str, |
|
community_names_lookup=community_names_lookup |
|
) |
|
st.plotly_chart(fig, use_container_width=True) |
|
|
|
with tab_editor: |
|
st.markdown("### Edit Community Names") |
|
st.markdown("Change the default community names in the table below. The new names will automatically update the graph legend and the timeline chart.") |
|
|
|
|
|
edited_df = st.data_editor( |
|
st.session_state.community_names_df, |
|
use_container_width=True, |
|
num_rows="dynamic" |
|
) |
|
|
|
|
|
st.session_state.community_names_df = edited_df |
|
|
|
st.download_button( |
|
label="Download Community Names as CSV", |
|
data=edited_df.to_csv().encode("utf-8"), |
|
file_name="community_names.csv", |
|
mime="text/csv", |
|
) |
|
|
|
with tab_timeline: |
|
st.markdown("### Community Size Over Time") |
|
|
|
|
|
community_names_lookup = st.session_state.community_names_df['community_names'].to_dict() |
|
|
|
selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=list(community_names_lookup.values())) |
|
resample_dict = {'Day': 'D', 'Week': 'W', 'Month': 'M', 'Quarter': 'Q', 'Year': 'Y'} |
|
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4) |
|
|
|
community_dict = {node: community_names_lookup.get(f'Community {i+1}') for i, comm_set in enumerate(communities) for node in comm_set} |
|
|
|
df_communities = st.session_state.dfx.copy() |
|
df_communities['Communities'] = df_communities['Hashtags'].apply(lambda tags: list(set(community_dict.get(tag) for tag in tags if tag in community_dict))) |
|
df_communities = df_communities.explode('Communities').dropna(subset=['Communities']) |
|
df_ts = df_communities.set_index('Post Created') |
|
df_community_sizes = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'Communities']).size().unstack(fill_value=0) |
|
|
|
existing_selected_cols = [col for col in selected_communities if col in df_community_sizes.columns] |
|
if existing_selected_cols: |
|
st.area_chart(df_community_sizes[existing_selected_cols]) |
|
else: |
|
st.warning("No data available for the selected communities.") |
|
|
|
with tab_lists: |
|
st.markdown("### Hashtag Communities (by importance)") |
|
st.dataframe(sorted_community_hashtags) |
|
st.markdown("### Top Edge Pairs (by weight)") |
|
st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(100)) |
|
|
|
elif selected_menu_option == "Clustering": |
|
st.markdown("## Clustering of Posts") |
|
st.markdown("This section allows you to group posts based on the similarity of their text or image content. Use the sidebar to configure the clustering process, then click 'Run Clustering' to see the results.") |
|
|
|
|
|
st.sidebar.markdown("# Clustering Options") |
|
st.sidebar.markdown("### Data & Algorithm") |
|
type_embeddings = st.sidebar.selectbox("Cluster based on:", ["Image", "Text"]) |
|
clustering_algo = st.sidebar.selectbox("Clustering Algorithm:", ["MiniBatchKMeans", "HDBSCAN", "KMeans"]) |
|
st.sidebar.info(f"**Tip:** `MiniBatchKMeans` is the fastest for a quick overview.") |
|
|
|
st.sidebar.markdown("### Algorithm Settings") |
|
if clustering_algo in ["KMeans", "MiniBatchKMeans"]: |
|
n_clusters = st.sidebar.slider("Number of Clusters (k)", 2, 50, 5, key="n_clusters_slider") |
|
if clustering_algo == "MiniBatchKMeans": |
|
batch_size = st.sidebar.slider("Batch Size", 32, 1024, 256, 32, help="Number of samples to use in each mini-batch.") |
|
max_iter = st.sidebar.slider("Max Iterations", 50, 500, 100, 50, help="Maximum number of iterations.") |
|
else: |
|
batch_size, max_iter = None, None |
|
min_cluster_size, min_samples = None, None |
|
elif clustering_algo == "HDBSCAN": |
|
min_cluster_size = st.sidebar.slider("Minimum Cluster Size", 2, 200, 15, help="Smallest size for a group to be a cluster.") |
|
min_samples = st.sidebar.slider("Minimum Samples", 1, 50, 5, help="Larger values lead to more points being declared as noise.") |
|
n_clusters, batch_size, max_iter = None, None, None |
|
|
|
st.sidebar.markdown("### Dimensionality Reduction") |
|
dim_reduction = st.sidebar.selectbox("Reduction Method:", ["PCA", "UMAP"]) |
|
st.sidebar.info(f"**Tip:** `PCA` is much faster than `UMAP`.") |
|
if dim_reduction == "UMAP": |
|
n_components = st.sidebar.slider("Number of Components", 2, 80, 50, help="Dimensions to reduce to before clustering.") |
|
n_neighbors = st.sidebar.slider("Number of Neighbors", 2, 50, 15, help="Controls UMAP's balance of local/global structure.") |
|
min_dist = st.sidebar.slider("Minimum Distance", 0.0, 1.0, 0.0, help="Controls how tightly UMAP packs points.") |
|
else: |
|
n_components = st.sidebar.slider("Number of Components", 2, 80, 2) |
|
n_neighbors, min_dist = None, None |
|
|
|
|
|
|
|
|
|
if st.button("Run Clustering", type="primary"): |
|
with st.spinner("Clustering embeddings... This may take a moment."): |
|
if type_embeddings == "Text": |
|
embeddings = dataset['text_embs'] |
|
else: |
|
embeddings = dataset['img_embs'] |
|
|
|
|
|
labels, reduced_embeddings = cluster_embeddings( |
|
embeddings, |
|
clustering_algo=clustering_algo, |
|
dim_reduction=dim_reduction, |
|
n_clusters=n_clusters, |
|
min_cluster_size=min_cluster_size, |
|
n_components=n_components, |
|
n_neighbors=n_neighbors, |
|
min_dist=min_dist, |
|
min_samples=min_samples, |
|
batch_size=batch_size, |
|
max_iter=max_iter |
|
) |
|
|
|
st.session_state['cluster_results'] = { |
|
"labels": labels, |
|
"reduced_embeddings": reduced_embeddings, |
|
"type_embeddings": type_embeddings, |
|
"clustering_algo": clustering_algo, |
|
"dim_reduction": dim_reduction |
|
} |
|
st.rerun() |
|
|
|
|
|
if 'cluster_results' in st.session_state: |
|
|
|
results = st.session_state['cluster_results'] |
|
labels = results['labels'] |
|
reduced_embeddings = results['reduced_embeddings'] |
|
|
|
num_found_clusters = len(set(labels) - {-1}) |
|
st.success(f"Clustering complete! Found **{num_found_clusters}** clusters using **{results['clustering_algo']}** on **{results['type_embeddings']}** embeddings with **{results['dim_reduction']}** reduction.") |
|
|
|
df_clustered = df.copy() |
|
df_clustered['cluster'] = labels |
|
|
|
|
|
tab1, tab2, tab3 = st.tabs(["📊 Results Table", "📈 2D Visualization", "🕒 Time Series Analysis"]) |
|
|
|
with tab1: |
|
st.markdown("### Clustered Data") |
|
st.dataframe( |
|
data=filter_dataframe(df_clustered), |
|
column_config={ |
|
"image": st.column_config.ImageColumn("Image", help="Instagram image"), |
|
"URL": st.column_config.LinkColumn("Link", help="Instagram link", width="small") |
|
}, |
|
hide_index=True, |
|
use_container_width=True |
|
) |
|
st.download_button( |
|
"Download Clustered Data as CSV", |
|
df_clustered.to_csv(index=False).encode('utf-8'), |
|
f'clustered_data_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv', |
|
"text/csv", |
|
key='download-cluster-csv' |
|
) |
|
|
|
with tab2: |
|
st.markdown("### Cluster Visualization") |
|
if reduced_embeddings.shape[1] > 2: |
|
with st.spinner("Reducing dimensions for 2D visualization..."): |
|
vis_reducer = umap.UMAP(n_components=2, random_state=42) |
|
vis_embeddings = vis_reducer.fit_transform(reduced_embeddings) |
|
else: |
|
vis_embeddings = reduced_embeddings |
|
|
|
df_plot_bokeh = pd.DataFrame(vis_embeddings, columns=('x', 'y')) |
|
df_plot_bokeh['description_clean'] = df_clustered['description_clean'] |
|
df_plot_bokeh['image_url'] = df_clustered['image'] |
|
df_plot_bokeh['cluster'] = labels |
|
|
|
unique_labels = sorted(list(set(labels))) |
|
color_dict = {label: rgb2hex(cc.glasbey_hv[i % len(cc.glasbey_hv)]) for i, label in enumerate(unique_labels)} |
|
df_plot_bokeh['color'] = df_plot_bokeh['cluster'].map(color_dict) |
|
|
|
source = ColumnDataSource(data=df_plot_bokeh) |
|
TOOLTIPS = """ |
|
<div style="width: 200px; padding: 5px; background-color: #f0f0f0; border-radius: 5px; font-family: sans-serif; border: 1px solid #cccccc;"> |
|
<div> |
|
<img src="@image_url" |
|
height="150" |
|
width="150" |
|
style="display: block; margin: auto;" |
|
border="0"> |
|
</img> |
|
</div> |
|
<hr style="border: 1px solid #aaaaaa; margin: 8px 0;"> |
|
<div style="text-align: left; padding: 0 5px;"> |
|
<span style="font-size: 12px; font-weight: bold;">Cluster: @cluster</span><br> |
|
<span style="font-size: 11px; word-wrap: break-word;">@description_clean</span> |
|
</div> |
|
</div> |
|
""" |
|
|
|
p = figure(width=800, height=800, tooltips=TOOLTIPS, title="2D Visualization of Post Clusters") |
|
p.circle('x', 'y', size=10, source=source, color='color', legend_field='cluster', line_color=None, alpha=0.8) |
|
p.legend.title = 'Cluster' |
|
p.legend.location = "top_left" |
|
st.bokeh_chart(p, use_container_width=True) |
|
|
|
with tab3: |
|
st.markdown("### Cluster Analysis Over Time") |
|
|
|
|
|
resample_dict = { |
|
'Day': 'D', |
|
'Week': 'W', |
|
'Month': 'M', |
|
'Quarter': 'Q', |
|
'Year': 'Y' |
|
} |
|
|
|
variable = st.selectbox('Select Variable for Time Series:', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'], key="cluster_ts_var") |
|
resample_time = st.selectbox('Resample Time By:', list(resample_dict.keys()), index=2, key="cluster_ts_resample") |
|
|
|
df_ts = df_clustered.copy() |
|
df_ts['Post Created'] = pd.to_datetime(df_ts['Post Created']) |
|
df_ts = df_ts.set_index('Post Created') |
|
df_ts = df_ts[df_ts['cluster'] != -1] |
|
|
|
if not df_ts.empty: |
|
|
|
df_plot = df_ts.groupby([pd.Grouper(freq=resample_dict[resample_time]), 'cluster'])[variable].sum().unstack(fill_value=0) |
|
st.line_chart(df_plot) |
|
else: |
|
st.warning("No data available for plotting (all points may have been classified as noise).") |
|
|
|
elif selected_menu_option == "Stats": |
|
st.markdown("### Time Series Analysis") |
|
|
|
variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']) |
|
|
|
|
|
resample_dict = { |
|
'Day': 'D', |
|
'Three Days': '3D', |
|
'Week': 'W', |
|
'Two Weeks': '2W', |
|
'Month': 'M', |
|
'Quarter': 'Q', |
|
'Year': 'Y' |
|
} |
|
|
|
|
|
resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys())) |
|
|
|
df_filtered = df.set_index('Post Created') |
|
|
|
|
|
min_date = df_filtered.index.min().date() |
|
max_date = df_filtered.index.max().date() |
|
|
|
date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date)) |
|
|
|
|
|
df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])] |
|
|
|
|
|
df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum() |
|
st.line_chart(df_resampled) |
|
|
|
st.markdown("### Correlation Analysis") |
|
|
|
options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'] |
|
scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options) |
|
|
|
scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options) |
|
|
|
|
|
st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}") |
|
|
|
scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2) |
|
|
|
st.plotly_chart(scatter_fig) |
|
|
|
|
|
corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2]) |
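    # Rule-of-thumb interpretation of Pearson's r used below: |r| > 0.7 strong, 0.3-0.7 moderate, < 0.3 weak.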
|
if corr > 0.7: |
|
st.write(f"The correlation coefficient is {corr}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.") |
|
elif corr > 0.3: |
|
st.write(f"The correlation coefficient is {corr}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.") |
|
elif corr > -0.3: |
|
st.write(f"The correlation coefficient is {corr}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.") |
|
elif corr > -0.7: |
|
st.write(f"The correlation coefficient is {corr}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.") |
|
else: |
|
st.write(f"The correlation coefficient is {corr}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.") |