import html
from typing import List

import numpy as np
import streamlit as st
import tweepy
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer

client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Maps the language option offered in the UI to the SentenceTransformer model to load.
model_to_use = {
    "English": "all-MiniLM-L12-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2",
}

# Page-size bounds accepted by the user-timeline endpoint for `max_results`.
_TWEETS_PAGE_MIN = 5
_TWEETS_PAGE_MAX = 100

# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
SEED = 42


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Load (and cache across Streamlit reruns) the SentenceTransformer `model_name`."""
    return SentenceTransformer(model_name)


def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in `text` with `model` and return the embeddings."""
    return model.encode(text)


def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 30,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project `embeddings` to `n_components` dimensions with t-SNE.

    Args:
        embeddings: One embedding per row.
        perplexity: Requested t-SNE perplexity. It is lowered automatically
            when there are too few samples for it.
        n_components: Output dimensionality.
        init: t-SNE initialization strategy.
        n_iter: Number of optimization iterations.
        random_state: Seed, for reproducible layouts.

    Returns:
        The fitted low-dimensional coordinates, one row per input row.
    """
    n_samples = len(embeddings)
    # scikit-learn requires perplexity < n_samples. Shrink it for small inputs
    # (e.g. only a few tweets fetched) instead of failing.
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    tsne = TSNE(
        perplexity=effective_perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)


def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str,
) -> Figure:
    """Build a Bokeh scatter plot of (xs, ys), colored by `values`, with hover tooltips.

    Args:
        texts: Per-point text shown in the tooltip. It is rendered with Bokeh's
            ``{safe}`` formatter (no HTML escaping), so callers must escape
            untrusted text themselves.
        xs: Point x coordinates.
        ys: Point y coordinates.
        values: Numeric value per point. Each distinct value gets its own color.
        labels: Per-point label shown in the tooltip.
        text_column: Tooltip caption for `texts`.
        label_column: Tooltip caption for `labels`.

    Returns:
        The configured Bokeh figure.
    """
    values = np.asarray(values)
    labels = np.asarray(labels)

    # One color factor per *distinct* value, in ascending numeric order. Each is
    # rescaled linearly to a palette index in [0, 255], so the color follows the value.
    distinct_values = np.unique(values)
    if len(distinct_values) <= 1:
        # Constant (or empty) input: use a single fixed palette entry.
        palette_ids = [1] * len(distinct_values)
    else:
        min_value, max_value = distinct_values[0], distinct_values[-1]
        scaled = (distinct_values - min_value) / (max_value - min_value) * 255
        palette_ids = scaled.round().astype(int).tolist()
    # Same dtype as `values`, so these strings match the per-point "label" strings.
    factors = distinct_values.astype(str).tolist()

    source = ColumnDataSource(
        data=dict(
            x=xs,
            y=ys,
            text=texts,
            label=values.astype(str).tolist(),
            original_label=labels.astype(str).tolist(),
        )
    )
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle(
        "x",
        "y",
        size=10,
        source=source,
        fill_color=factor_cmap("label", palette=[Pallete[i] for i in palette_ids], factors=factors),
    )
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p
# End of the code adapted from embedding-lenses.


def generate_plot(
    df: List[str],
    labels: List[int],
    model: SentenceTransformer,
) -> Figure:
    """Embed `df`, reduce it to 2-D with t-SNE, and return the interactive scatter plot.

    Args:
        df: Texts to embed (tweet texts in this app).
        labels: One numeric label per text. These drive the point colors.
        model: Sentence embedding model used to encode `df`.

    Returns:
        The Bokeh figure produced by `draw_interactive_scatter_plot`.
    """
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df, model)
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = get_tsne_embeddings(embeddings)
    # The tooltip renders text with {safe} (no escaping). The text may already
    # contain HTML entities such as "&amp;", so unescape first and then escape
    # exactly once. Entities still display correctly, and no markup coming from
    # the text is interpreted by the browser.
    display_texts = [html.escape(html.unescape(text)) for text in df]
    return draw_interactive_scatter_plot(
        display_texts, embeddings_2d[:, 0], embeddings_2d[:, 1], labels, labels, "text", "label"
    )


def _fetch_user_tweets(user_id, max_tweets: int) -> list:
    """Return up to `max_tweets` tweets of `user_id`, following `next_token` pagination."""
    tweets = []
    pagination_token = None
    while len(tweets) < max_tweets:
        remaining = max_tweets - len(tweets)
        # Keep the page size within the accepted bounds. Any surplus from the
        # minimum page size is trimmed after the loop.
        page_size = min(_TWEETS_PAGE_MAX, max(_TWEETS_PAGE_MIN, remaining))
        response = client.get_users_tweets(
            user_id, max_results=page_size, pagination_token=pagination_token
        )
        if response.data:  # None when the page holds no tweets
            tweets.extend(response.data)
        pagination_token = (response.meta or {}).get("next_token")
        if not pagination_token:
            break
    return tweets[:max_tweets]


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")

col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

expected_lang = st.radio(
    "What language should be assumed to be found?",
    ('English', 'Use all the ones you know (~15 lang)'),
    0,
)

with st.spinner(text="Loading model..."):
    model = load_model(model_to_use[expected_lang])

# Only query the API once a handle has been entered.
if tw_user:
    try:
        usr = client.get_user(username=tw_user)
    except tweepy.TweepyException as err:
        st.error(f"Could not look up '{tw_user}': {err}")
        st.stop()
    if usr.data is None:
        st.error(f"Twitter user '{tw_user}' was not found.")
        st.stop()

    with st.spinner(f"Getting to know the '{tw_user}'..."):
        tweets_objs = _fetch_user_tweets(usr.data.id, int(tw_sample))
    tweets_txt = [tweet.text for tweet in tweets_objs]

    # t-SNE needs at least two points to lay out.
    if len(tweets_txt) < 2:
        st.warning(
            f"Found only {len(tweets_txt)} tweet(s) for '{tw_user}'; at least 2 are needed to draw the map."
        )
        st.stop()

    # Every tweet currently shares the same label, so all points get one color.
    labels = [0] * len(tweets_txt)
    plot = generate_plot(tweets_txt, labels, model)
    st.bokeh_chart(plot)