"""Streamlit data explorer for the #ditaduranuncamais dataset.

The app offers three tabs: a filterable table of the posts, a semantic
search (text/image queries against FAISS-indexed text/image embeddings),
and a few simple statistics plots.
"""
import os
import re
import subprocess

import datasets
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from pandas.api.types import (  # noqa: F401  (is_categorical_dtype kept for compatibility)
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
from sentence_transformers import SentenceTransformer, util  # noqa: F401

# Must be the first Streamlit command executed by the script.
st.set_page_config(layout="wide")

model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"
MODEL_ZIP_NAME = "clip-ViT-B-32-multilingual-v1.zip"
MODEL_ZIP_URL = (
    "https://public.ukp.informatik.tu-darmstadt.de/reimers/"
    "sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"
)

# Columns tried, in order, as the time axis of the "Stats" tab.
# NOTE(review): load_dataset() drops "Post Created Date", so the first
# candidate is normally absent; "Post Created" is presumably the remaining
# timestamp column of the export -- confirm against the dataset schema.
DATE_COLUMN_CANDIDATES = ("Post Created Date", "Post Created")


def download_models():
    """Download and unpack the multilingual CLIP text model into ``model_dir``.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` fails.
    """
    # Not wrapped in st.cache_data: this function exists only for its side
    # effect, and a cached no-op would skip the download if the directory
    # disappeared while the server keeps running.
    os.makedirs(os.path.dirname(model_dir), exist_ok=True)
    # NOTE(review): --no-check-certificate disables TLS verification for the
    # model download; kept from the original, consider removing it.
    # -O pins the output name so a leftover partial download is overwritten
    # instead of wget saving to "<name>.1" and unzip reading the stale file.
    subprocess.run(
        ["wget", "--no-check-certificate", "-O", MODEL_ZIP_NAME, MODEL_ZIP_URL],
        check=True,
    )
    subprocess.run(["unzip", "-q", MODEL_ZIP_NAME, "-d", model_dir], check=True)


# Fetch the text model only when it is not already on disk.
if not os.path.exists(model_dir):
    with st.spinner("Downloading the text embedding model..."):
        download_models()

token = os.getenv('token')


@st.cache_resource(show_spinner=True)
def load_dataset():
    """Load the dataset, attach FAISS indexes and drop unused columns.

    Cached with ``st.cache_resource`` so the single Dataset object (with its
    FAISS indexes) is shared across reruns instead of being pickled and
    copied on every access, which is what ``st.cache_data`` would do.

    Returns:
        datasets.Dataset: Dataset with FAISS indexes on ``txt_embs`` and
        ``img_embs``.
    """
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token)
    dataset.add_faiss_index(column="txt_embs")
    dataset.add_faiss_index(column="img_embs")
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Type', 'Like and View Counts Disabled', 'Link', 'Photo', 'Title', 'Sponsor Id', 'Sponsor Name'])
    return dataset


@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    """Convert the dataset to a pandas DataFrame without the heavy columns.

    Parameters:
        _dataset (datasets.Dataset): Source dataset. The leading underscore
            tells Streamlit not to hash this argument for the cache key.

    Returns:
        pd.DataFrame: Dataset without embeddings and base64 image data.
    """
    dataframe = _dataset.remove_columns(['txt_embs', 'img_embs', 'image_base64']).to_pandas()
    return dataframe


@st.cache_resource(show_spinner=True)
def load_img_model():
    """Load the image encoder (the original clip-ViT-B-32)."""
    return SentenceTransformer('clip-ViT-B-32')


@st.cache_resource(show_spinner=True)
def load_txt_model():
    """Load the multilingual text encoder.

    Per the original comment, this model is aligned to the image model and
    maps 50+ languages into the same vector space.
    """
    return SentenceTransformer(model_dir)


def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Parameters:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    modify = st.checkbox("Add filters")
    if not modify:
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Deliberate best effort: columns that do not parse as dates
                # simply keep their original dtype.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")
            # Treat columns with < 10 unique values as categorical
            if isinstance(df[column].dtype, pd.CategoricalDtype) or df[column].nunique() < 10:
                unique_values = df[column].unique()
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    unique_values,
                    default=list(unique_values),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # While the user is still picking the range only one date is
                # returned; filter once both ends are known.
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    try:
                        # astype(str) keeps the mask strictly boolean even
                        # when the column holds missing or non-string values.
                        mask = df[column].astype(str).str.contains(user_text_input)
                    except re.error as err:
                        right.warning(f"Invalid regular expression: {err}")
                    else:
                        df = df[mask]

    return df


@st.cache_data
def get_image_embs(image):
    """
    Get image embeddings

    Parameters:
        image: Input forwarded unchanged to ``image_model.encode``

    Returns:
        img_emb (torch.Tensor): Image embeddings
    """
    img_emb = image_model.encode(image, convert_to_tensor=True)
    return img_emb


@st.cache_data(show_spinner=False)
def get_text_embs(text):
    """
    Get text embeddings

    Parameters:
        text (str): Text to encode

    Returns:
        text_emb (torch.Tensor): Text embeddings
    """
    txt_emb = text_model.encode(text, convert_to_tensor=True)
    return txt_emb


@st.cache_data
def postprocess_results(scores, samples):
    """
    Postprocess results into a dataframe with relative scores and labels

    Parameters:
        scores (np.array): Scores
        samples (dict): Retrieved samples, one list per column

    Returns:
        pd.DataFrame: Samples without the embedding columns, plus an integer
            ``scores`` column (0-100, the lowest raw score maps to 100) and a
            ``label`` column combining ``text_full`` with the score
    """
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores

    raw = samples_df["scores"]
    lowest = raw.min()
    span = raw.max() - lowest
    if span > 0:
        relative = 1 - (raw - lowest) / span
    else:
        # A single result (k=1) or identical scores would divide by zero and
        # make the integer cast below fail on NaN; treat them all as the best
        # match instead.
        relative = pd.Series(1.0, index=raw.index)
    samples_df["scores"] = (relative * 100).astype(int)

    samples_df.reset_index(inplace=True, drop=True)
    samples_df['label'] = samples_df['text_full'] + ' (' + samples_df['scores'].astype(str) + '%)'
    return samples_df.drop(columns=['txt_embs', 'img_embs'])


def _as_query_vector(embedding):
    """Convert an embedding (torch tensor or array-like) to float32 NumPy."""
    if hasattr(embedding, "detach"):
        embedding = embedding.detach().cpu().numpy()
    return np.asarray(embedding, dtype=np.float32)


def _nearest(index_name, embedding, k):
    """Query one FAISS index of the dataset and postprocess the hits."""
    scores, samples = dataset.get_nearest_examples(index_name, _as_query_vector(embedding), k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def text_to_text(text, k=5):
    """
    Text to text

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Nearest posts by text embedding, see postprocess_results
    """
    return _nearest('txt_embs', get_text_embs(text), k)


@st.cache_data
def image_to_text(image, k=5):
    """
    Image to text

    Parameters:
        image: Uploaded file object exposing a ``name`` attribute
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Nearest posts by text embedding, see postprocess_results
    """
    # NOTE(review): only image.name (a string) is encoded, not the image
    # content; presumably a leftover from a temp-file based UI. The CLIP
    # encoder likely needs a PIL image here -- confirm and decode the upload.
    return _nearest('txt_embs', get_image_embs(image.name), k)


@st.cache_data
def text_to_image(text, k=5):
    """
    Text to image

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Nearest posts by image embedding, see postprocess_results
    """
    return _nearest('img_embs', get_text_embs(text), k)


@st.cache_data
def image_to_image(image, k=5):
    """
    Image to image

    Parameters:
        image: Uploaded file object exposing a ``name`` attribute
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Nearest posts by image embedding, see postprocess_results
    """
    # NOTE(review): same caveat as image_to_text -- image.name is a string,
    # not the image content.
    return _nearest('img_embs', get_image_embs(image.name), k)


def _show_posts_table(data):
    """Render posts with the image and link columns configured."""
    st.dataframe(
        data=data,
        column_config={
            "Download URL": st.column_config.ImageColumn(
                "image", help="Instagram image"
            ),
            "URL": st.column_config.LinkColumn(
                "link", help="Instagram link", width="small"
            ),
        },
        hide_index=True,
    )


dataset = load_dataset()
df = load_dataframe(dataset)
image_model = load_img_model()
text_model = load_txt_model()

# Search type label -> (search function, kind of query input it expects).
SEARCH_MODES = {
    "Text to Text": (text_to_text, "text"),
    "Text to Image": (text_to_image, "text"),
    "Image to Image": (image_to_image, "image"),
    "Image to Text": (image_to_text, "image"),
}

st.title("#ditaduranuncamais Data Explorer")
st.title(f'My first app {st.__version__}')

tab1, tab2, tab3 = st.tabs(["Data exploration", "Semantic search", "Stats"])

with tab1:
    _show_posts_table(filter_dataframe(df))

with tab2:
    selected_tab = st.radio("Select a search type", list(SEARCH_MODES))
    search_fn, input_kind = SEARCH_MODES[selected_tab]

    # Widget order matches the original layout: text modes show the query
    # first, image modes show the result-count slider first.
    if input_kind == "text":
        query = st.text_input("Enter text")
        k_top = st.slider("Number of results", 1, 20, 8)
    else:
        k_top = st.slider("Number of results", 1, 20, 8)
        query = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    if st.button("Search"):
        if input_kind == "image" and query is None:
            # Without this guard the search crashes on ``None.name``.
            st.warning("Please upload an image before searching.")
        else:
            _show_posts_table(search_fn(query, k_top))

with tab3:
    st.markdown("### Time Series Analysis")

    # Dropdown to select variables
    variable = st.selectbox('Select Variable', ['Total Interactions', 'Likes', 'Comments', 'Overperforming Score (weighted — Likes 1x Comments 1x )'])

    # Dropdown to select time resampling
    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }
    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    date_column = next((c for c in DATE_COLUMN_CANDIDATES if c in df.columns), None)
    if date_column is None or variable not in df.columns:
        st.warning("The columns needed for the time series are not available in the data.")
    else:
        timeline = df[[date_column, variable]].copy()
        # Coerce so the index is a real DatetimeIndex even when the column
        # arrives as strings or date objects; unparseable rows are dropped.
        timeline[date_column] = pd.to_datetime(timeline[date_column], errors="coerce")
        timeline = timeline.dropna(subset=[date_column]).set_index(date_column)

        if timeline.empty:
            st.warning("No dated posts are available for the time series.")
        else:
            # Slider for date range selection
            min_date = timeline.index.min().date()
            max_date = timeline.index.max().date()
            if min_date < max_date:
                date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))
            else:
                # A slider needs min < max; with a single day there is
                # nothing to select.
                date_range = (min_date, max_date)

            # Filter on the selected date range, then resample and plot
            in_range = (timeline.index.date >= date_range[0]) & (timeline.index.date <= date_range[1])
            st.line_chart(timeline.loc[in_range, variable].resample(resample_dict[resample_time]).sum())

    # Offer only scatter variables that actually exist in the dataframe;
    # plotting a missing column would raise inside plotly.
    scatter_options = [c for c in ('num_comments', 'score', 'cosine') if c in df.columns]
    if not scatter_options:
        st.info("None of the scatter plot variables are available in the data.")
    else:
        scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', scatter_options)
        scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', scatter_options)

        st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")
        scatter_fig = px.scatter(df, x=scatter_variable_1, y=scatter_variable_2)
        st.plotly_chart(scatter_fig)