# rjadr's picture
# Update app.py
# 95bd107
# raw
# history blame
# 18.8 kB
import pandas as pd
import streamlit as st
import datasets
import plotly.express as px
from sentence_transformers import SentenceTransformer, util
import os
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
import subprocess
# Configure the Streamlit page to use the wide layout.
st.set_page_config(layout="wide")
# Local directory the multilingual CLIP text model is unzipped into by
# download_models(); its existence is checked at startup to decide whether
# the download is needed.
model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
@st.cache_data(show_spinner=True)
def download_models():
    """Download and unpack the multilingual CLIP text model into ``model_dir``.

    The archive is fetched with ``wget`` into the current working directory
    and extracted with ``unzip``; both executables must be available on PATH.

    Raises:
        subprocess.CalledProcessError: If ``wget`` or ``unzip`` exits with a
            non-zero status.
        FileNotFoundError: If one of the executables is not installed.
    """
    # Create the parent folder in-process: unlike the previous unchecked
    # ``mkdir`` subprocess this is portable and is a no-op when the folder
    # already exists.
    os.makedirs(os.path.dirname(model_dir), exist_ok=True)
    # NOTE(review): --no-check-certificate disables TLS verification, so the
    # downloaded archive is not authenticated. Kept to preserve behavior;
    # drop the flag if the host's certificate validates.
    subprocess.run(["wget", "--no-check-certificate", "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"], check=True)
    subprocess.run(["unzip", "-q", "clip-ViT-B-32-multilingual-v1.zip", "-d", model_dir], check=True)
# Access token read from the ``token`` environment variable (``None`` when the
# variable is unset); load_dataset() passes it to ``datasets.load_dataset``.
token = os.getenv('token')
@st.cache_data(show_spinner=True)
def load_dataset():
    """Load the ``train`` split of the posts dataset and prepare it for search.

    A FAISS index is attached to each embedding column (``txt_embs`` and
    ``img_embs``) and the columns the app never shows are removed.

    Returns:
        The dataset returned by ``remove_columns`` after indexing.
    """
    unused_columns = [
        'Post Created',
        'Post Created Time',
        'Like and View Counts Disabled',
        'Link',
        'Photo',
        'Title',
        'Sponsor Id',
        'Sponsor Name',
        'Download URL',
        'image',
        'Views',
        'text_full',
    ]
    corpus = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token)
    # One nearest-neighbour index per embedding column, text first.
    for embedding_column in ("txt_embs", "img_embs"):
        corpus.add_faiss_index(column=embedding_column)
    return corpus.remove_columns(unused_columns)
@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    """Convert the dataset into a pandas DataFrame for browsing and statistics.

    The embedding columns are dropped, ``image_base64`` is decoded from bytes
    to UTF-8 text, and two columns holding numbers formatted with thousands
    separators are converted to numeric dtypes.

    Parameters:
        _dataset: Dataset providing ``remove_columns`` and ``to_pandas``.
            NOTE(review): the leading underscore presumably tells Streamlit's
            cache not to hash this argument - confirm against the
            ``st.cache_data`` documentation.
    Returns:
        pd.DataFrame: One row per post, without ``txt_embs`` / ``img_embs``.
    """
    score_column = 'Overperforming Score (weighted β€” Likes 1x Comments 1x )'
    frame = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()
    frame['image_base64'] = frame['image_base64'].str.decode('utf-8')
    # Strip the "," thousands separators before casting to a number.
    for column, target_type in ((score_column, float), ('Total Interactions', int)):
        frame[column] = frame[column].str.replace(',', '').astype(target_type)
    return frame
@st.cache_resource(show_spinner=True)
def load_img_model():
    """Return the SentenceTransformer used to embed images.

    The original (non-multilingual) ``clip-ViT-B-32`` checkpoint is used for
    the image side.
    """
    image_encoder = SentenceTransformer('clip-ViT-B-32')
    return image_encoder
@st.cache_resource(show_spinner=True)
def load_txt_model():
    """Return the SentenceTransformer used to embed text queries.

    The model is loaded from ``model_dir``, the same location the startup
    check inspects and ``download_models()`` extracts into, so the path is
    defined in exactly one place.
    """
    # Our text embedding model is aligned to the img_model and maps 50+
    # languages to the same vector space
    return SentenceTransformer(model_dir)
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    When the "Add filters" checkbox is off the input is returned untouched.
    Otherwise a copy is filtered column by column with a widget chosen from
    the column's dtype: multiselect (categorical or fewer than 10 distinct
    values), range slider (numeric), date range (datetime) or a
    substring/regex text box (everything else).

    Args:
        df (pd.DataFrame): Original dataframe
    Returns:
        pd.DataFrame: Filtered dataframe
    """
    import re  # only needed to recognise invalid user-supplied patterns

    modify = st.checkbox("Add filters")
    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberate best effort: a column that does not parse as
                # dates simply keeps its original values.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # Treat columns with < 10 unique values as categorical.
            # The isinstance check replaces the deprecated
            # pandas.api.types.is_categorical_dtype helper.
            is_categorical = isinstance(df[column].dtype, pd.CategoricalDtype)
            if is_categorical or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                # 100 slider steps across the observed range.
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # Only filter once both ends of the range have been picked.
                if len(user_date_input) == 2:
                    start_date, end_date = map(pd.to_datetime, user_date_input)
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    try:
                        # na=False: rows with a missing value never match,
                        # instead of producing a NaN mask that cannot be
                        # used for boolean indexing.
                        matches = df[column].str.contains(user_text_input, na=False)
                    except re.error as exc:
                        # A half-typed pattern should not take the app down.
                        right.error(f"Invalid regular expression: {exc}")
                    else:
                        df = df[matches]

    return df
@st.cache_data
def get_image_embs(image):
    """
    Get image embeddings
    Parameters:
        image: Value handed unchanged to ``image_model.encode``
    Returns:
        Whatever ``image_model.encode`` returns for that input
    """
    # image_model is the module-level model assigned by the main script.
    return image_model.encode(image)
@st.cache_data(show_spinner=False)
def get_text_embs(text):
    """
    Get text embeddings
    Parameters:
        text (str): Text to encode
    Returns:
        Whatever ``text_model.encode`` returns for that text
    """
    # text_model is the module-level model assigned by the main script.
    return text_model.encode(text)
@st.cache_data
def postprocess_results(scores, samples):
    """
    Turn raw nearest-neighbour output into a display-ready dataframe

    The raw scores are rescaled to integers from 0 to 100 with the scale
    inverted: the smallest raw score becomes 100 and the largest becomes 0.
    When every score is identical (always the case for a single result) all
    rows get 100; the previous min-max formula divided by zero there and the
    integer cast then failed on NaN.

    Parameters:
        scores (np.array): Scores, one per sample
        samples (dict): Column name -> list of values for the retrieved rows
    Returns:
        pd.DataFrame: Samples plus a ``score`` column, with the six main
            columns first and the ``txt_embs`` / ``img_embs`` columns removed
    """
    lead_columns = ['Post Created Date', 'image_base64', 'Description', 'Image Text', 'Account', 'User Name']

    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["score"] = scores

    lowest = samples_df["score"].min()
    spread = samples_df["score"].max() - lowest
    if spread > 0:
        samples_df["score"] = (1 - (samples_df["score"] - lowest) / spread) * 100
    else:
        # No variation between results: nothing to rank, show them as equal.
        samples_df["score"] = 100
    samples_df["score"] = samples_df["score"].astype(int)

    samples_df = samples_df.reset_index(drop=True)
    other_columns = [col for col in samples_df.columns if col not in lead_columns]
    samples_df = samples_df[lead_columns + other_columns]
    return samples_df.drop(columns=['txt_embs', 'img_embs'])
@st.cache_data
def text_to_text(text, k=5):
    """
    Text to text: search the text-embedding index with a text query
    Parameters:
        text (str): Input text
        k (int): Number of top results to return
    Returns:
        The value returned by ``postprocess_results`` for the ``k`` nearest
        entries of the ``txt_embs`` index
    """
    query_vector = get_text_embs(text)
    distances, neighbours = dataset.get_nearest_examples('txt_embs', query_vector, k=k)
    return postprocess_results(distances, neighbours)
@st.cache_data
def image_to_text(image, k=5):
    """
    Image to text: search the text-embedding index with an image query

    Parameters:
        image: Uploaded file object exposing ``getvalue()`` (such as the
            value returned by ``st.file_uploader``)
        k (int): Number of top results to return
    Returns:
        The value returned by ``postprocess_results`` for the ``k`` nearest
        entries of the ``txt_embs`` index
    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded)
    """
    if image is None:
        raise ValueError("No image was provided for the search.")
    # ``image.name`` is only the upload's file name; handing that string to
    # the CLIP model would embed the *name as text* rather than the picture.
    # Decode the uploaded bytes into an image object instead.
    decoded_image = datasets.Image().decode_example({"path": None, "bytes": image.getvalue()})
    img_emb = get_image_embs(decoded_image)
    scores, samples = dataset.get_nearest_examples('txt_embs', img_emb, k=k)
    return postprocess_results(scores, samples)
@st.cache_data
def text_to_image(text, k=5):
    """
    Text to image: search the image-embedding index with a text query
    Parameters:
        text (str): Input text
        k (int): Number of top results to return
    Returns:
        The value returned by ``postprocess_results`` for the ``k`` nearest
        entries of the ``img_embs`` index
    """
    query_vector = get_text_embs(text)
    distances, neighbours = dataset.get_nearest_examples('img_embs', query_vector, k=k)
    return postprocess_results(distances, neighbours)
@st.cache_data
def image_to_image(image, k=5):
    """
    Image to image: search the image-embedding index with an image query

    Parameters:
        image: Uploaded file object exposing ``getvalue()`` (such as the
            value returned by ``st.file_uploader``)
        k (int): Number of top results to return
    Returns:
        The value returned by ``postprocess_results`` for the ``k`` nearest
        entries of the ``img_embs`` index
    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded)
    """
    if image is None:
        raise ValueError("No image was provided for the search.")
    # ``image.name`` is only the upload's file name; handing that string to
    # the CLIP model would embed the *name as text* rather than the picture.
    # Decode the uploaded bytes into an image object instead.
    decoded_image = datasets.Image().decode_example({"path": None, "bytes": image.getvalue()})
    img_emb = get_image_embs(decoded_image)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)
def check_password():
    """Returns `True` if the user had the correct password.

    Until the correct password has been entered in this session a password
    field is rendered (plus an error message after a failed attempt) and
    `False` is returned. The expected value is read from
    ``st.secrets["password"]``.
    """
    import hmac  # local import: keeps this block self-contained

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        supplied = str(st.session_state["password"]).encode("utf-8")
        expected = str(st.secrets["password"]).encode("utf-8")
        # Constant-time comparison so response timing does not reveal how
        # much of the password matched; bytes are used because
        # compare_digest rejects non-ASCII str arguments.
        if hmac.compare_digest(supplied, expected):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store password
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        # Password correct.
        return True

    # First run or wrong password: show the input again.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if "password_correct" in st.session_state:
        # The key only exists after an attempt, so reaching here means it failed.
        st.error("😕 Password incorrect")
    return False
# Main page: rendered only after the password gate has been passed.
if check_password():
    # Check if the directory exists; download the text model on first launch
    if not os.path.exists(model_dir):
        download_models()

    # Module-level objects the search helpers above read as globals.
    dataset = load_dataset()
    df = load_dataframe(dataset)
    image_model = load_img_model()
    text_model = load_txt_model()

    # Rendering options shared by every table of posts.
    post_column_config = {
        "image_base64": st.column_config.ImageColumn(
            "image", help="Instagram image"
        ),
        "URL": st.column_config.LinkColumn(
            "link", help="Instagram link", width="small"
        ),
    }
    # Columns shown first in the data-exploration table.
    lead_columns = ['Post Created Date', 'image_base64', 'Description', 'Image Text', 'Account', 'User Name']
    # Numeric columns offered in the statistics tab.
    stat_columns = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']

    st.title("#ditaduranuncamais Data Explorer")

    tab1, tab2, tab3 = st.tabs(["Data exploration", "Semantic search", "Stats"])

    with tab1:
        # Initialization
        if 'rows_per_page' not in st.session_state:
            st.session_state['rows_per_page'] = 25
        if 'page_number' not in st.session_state:
            st.session_state['page_number'] = 1

        filtered_df = filter_dataframe(df)

        # Ceiling division, but never fewer than one page: an empty result
        # set would otherwise leave the page selector without options.
        max_page = max(1, -(-len(filtered_df) // st.session_state['rows_per_page']))
        # A filter can shrink the result set below the page currently
        # selected; clamp so the selector index below stays valid.
        st.session_state['page_number'] = min(max(1, st.session_state['page_number']), max_page)

        start_index = st.session_state['rows_per_page'] * (st.session_state['page_number'] - 1)
        end_index = start_index + st.session_state['rows_per_page']
        sub_df = filtered_df.iloc[start_index:end_index]
        # sort columns order: lead columns first and then the rest
        sub_df = sub_df[lead_columns + [col for col in sub_df.columns if col not in lead_columns]]

        col1, col2, col3, col4 = st.columns(4)
        with col4:
            rows_per_page = st.selectbox('Rows per page', [25, 50, 75, 100, 150, 200], index=0, key='rows_per_page_select')
            if rows_per_page != st.session_state['rows_per_page']:
                st.session_state['rows_per_page'] = rows_per_page
                st.session_state['page_number'] = 1  # Reset page number when rows per page changes
                st.experimental_rerun()
        with col2:
            page_select = st.selectbox('Jump to page', options=range(1, max_page + 1), index=st.session_state['page_number']-1, key='page_number_select')
            if page_select != st.session_state['page_number']:
                st.session_state['page_number'] = page_select
                st.experimental_rerun()
        with col1:
            if st.button('Previous'):
                st.session_state['page_number'] = max(1, st.session_state['page_number'] - 1)
                st.experimental_rerun()
        with col3:
            if st.button('Next'):
                st.session_state['page_number'] = min(max_page, st.session_state['page_number'] + 1)
                st.experimental_rerun()

        # The index column stays visible here (hide_index is not set).
        st.dataframe(
            data=sub_df,
            column_config=post_column_config,
        )

    with tab2:
        tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"]
        selected_tab = st.radio("Select a search type", tabs)

        text_searches = {"Text to Text": text_to_text, "Text to Image": text_to_image}
        image_searches = {"Image to Image": image_to_image, "Image to Text": image_to_text}

        if selected_tab in text_searches:
            query_text = st.text_input("Enter text")
            k_top = st.slider("Number of results", 1, 20, 8)
            if st.button("Search"):
                st.dataframe(
                    data=text_searches[selected_tab](query_text, k_top),
                    column_config=post_column_config,
                    hide_index=True,
                )
        elif selected_tab in image_searches:
            k_top = st.slider("Number of results", 1, 20, 8)
            uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
            if st.button("Search"):
                if uploaded_image is None:
                    # Searching without an upload used to crash the page.
                    st.warning("Please upload an image before searching.")
                else:
                    st.dataframe(
                        data=image_searches[selected_tab](uploaded_image, k_top),
                        column_config=post_column_config,
                        hide_index=True,
                    )

    with tab3:
        st.markdown("### Time Series Analysis")
        # Dropdown to select variables
        variable = st.selectbox('Select Variable', stat_columns)

        # Display label -> pandas resampling rule
        resample_dict = {
            'Day': 'D',
            'Three Days': '3D',
            'Week': 'W',
            'Two Weeks': '2W',
            'Month': 'M',
            'Quarter': 'Q',
            'Year': 'Y'
        }
        # Dropdown to select time resampling
        resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

        df_filtered = df.set_index('Post Created Date')

        # Slider for date range selection
        min_date = df_filtered.index.min().date()
        max_date = df_filtered.index.max().date()
        date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

        # Filter dataframe based on selected date range (both ends inclusive)
        df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])]

        # Sum the chosen variable per resampling period and plot it
        df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum()
        st.line_chart(df_resampled)

        st.markdown("### Correlation Analysis")
        # Dropdowns to select variables for scatter plot
        scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', stat_columns)
        scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', stat_columns)

        # Plot scatter chart with a fitted trend line
        st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")
        scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2, trendline='ols', trendline_color_override='red')
        st.plotly_chart(scatter_fig)

        # calculate correlation for scatter_variable_1 with scatter_variable_2
        corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2])

        if pd.isna(corr):
            # NaN compares False against every threshold and would otherwise
            # be reported as a strong negative relationship.
            st.write(f"The correlation coefficient could not be computed for {scatter_variable_1} and {scatter_variable_2} (not enough variation in the selected data).")
        else:
            if corr > 0.7:
                strength = "strong positive"
            elif corr > 0.3:
                strength = "moderate positive"
            elif corr > -0.3:
                strength = "weak or no"
            elif corr > -0.7:
                strength = "moderate negative"
            else:
                strength = "strong negative"
            st.write(f"The correlation coefficient is {corr}, indicating a {strength} relationship between {scatter_variable_1} and {scatter_variable_2}.")