Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	| from typing import List | |
| import itertools | |
| import string | |
| import re | |
| import requests | |
| import tweepy | |
| import hdbscan | |
| import numpy as np | |
| import streamlit as st | |
| from gensim.utils import deaccent | |
| from bokeh.models import ColumnDataSource, HoverTool, Label, Legend | |
| from bokeh.palettes import Colorblind as Pallete | |
| from bokeh.palettes import Set3 as AuxPallete | |
| from bokeh.plotting import Figure, figure | |
| from bokeh.transform import factor_cmap | |
| from sklearn.manifold import TSNE | |
| from sentence_transformers import SentenceTransformer, util | |
# Twitter API v2 client, authenticated with the bearer token stored in Streamlit secrets.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Maps the language option shown in the UI to the sentence-transformers model name to load.
model_to_use = {
    "English": "all-MiniLM-L6-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2",
}

# Stop words are downloaded from a pinned gist revision each time this module runs.
# The timeout keeps a stalled connection from hanging the app indefinitely, and
# raise_for_status() prevents an HTTP error page from being parsed as stop words.
_stopwords_response = requests.get(
    "https://gist.githubusercontent.com/rg089/35e00abf8941d72d419224cfd5b5925d/raw/12d899b70156fd0041fa9778d657330b024b959c/stopwords.txt",
    timeout=10,
)
_stopwords_response.raise_for_status()
stopwords_list = _stopwords_response.content
# One stop word per line; kept as a set for O(1) membership tests.
stopwords = set(stopwords_list.decode().splitlines())
def _remove_unk_chars(txt_list: List[str]) -> List[str]:
    """Normalise whitespace, apostrophes, accents and case of every tweet.

    For each string: runs of whitespace collapse to a single space, apostrophes
    are removed, the text is passed through gensim's ``deaccent`` and finally
    lower-cased.

    Args:
        txt_list: Raw tweet texts.

    Returns:
        A new list with the normalised texts, in the same order.
    """
    normalised = []
    for tweet in txt_list:
        # Raw string: '\s' in a plain literal is an invalid escape sequence.
        tweet = re.sub(r'\s+', ' ', tweet)
        tweet = tweet.replace("'", "")
        normalised.append(deaccent(tweet).lower())
    return normalised
| def _remove_urls(txt_list: List[str]): | |
| url_regex = re.compile( | |
| r'^(?:http|ftp)s?://' # http:// or https:// | |
| r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... | |
| r'localhost|' #localhost... | |
| r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip | |
| r'(?::\d+)?' # optional port | |
| r'(?:/?|[/?]\S+)$', re.IGNORECASE) | |
| txt_list = [tweet.split(' ') for tweet in txt_list] | |
| return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list] | |
| def _remove_punctuation(txt_list: List[str]): | |
| punctuation = string.punctuation + 'ΒΏΒ‘|' | |
| txt_list = [tweet.split(' ') for tweet in txt_list] | |
| return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list] | |
def _remove_stopwords(txt_list: List[str]):
    """Remove the space-separated tokens of each tweet found in ``stopwords``."""
    filtered = []
    for tweet in txt_list:
        kept_tokens = (token for token in tweet.split(' ') if token not in stopwords)
        filtered.append(' '.join(kept_tokens))
    return filtered
# Cleaning steps in the order they are applied: each takes a list of tweet
# texts and returns a new list of the same length.
preprocess_pipeline = [_remove_unk_chars, _remove_urls, _remove_punctuation, _remove_stopwords]
def preprocess(txt_list: List[str]) -> List[str]:
    """Run every step of ``preprocess_pipeline`` over the tweets, in order.

    Args:
        txt_list: Raw tweet texts.

    Returns:
        The texts after all pipeline steps have been applied; each step
        receives the output of the previous one.
    """
    for step in preprocess_pipeline:
        txt_list = step(txt_list)
    return txt_list
# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py

# Default random state for the t-SNE projection.
SEED = 42
def load_model(model_name: str) -> SentenceTransformer:
    """Build the ``SentenceTransformer`` identified by ``model_name``."""
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode ``text`` with ``model`` and return the result of ``model.encode``."""
    encoded = model.encode(text)
    return encoded
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 10,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` down to ``n_components`` dimensions with t-SNE.

    Args:
        embeddings: One row per sample.
        perplexity: Requested t-SNE perplexity; lowered automatically when
            there are too few samples for it to be valid.
        n_components: Number of output dimensions.
        init: Initialisation strategy forwarded to ``TSNE``.
        n_iter: Number of optimisation iterations forwarded to ``TSNE``.
        random_state: Seed forwarded to ``TSNE``.

    Returns:
        The result of ``TSNE.fit_transform`` on ``embeddings``.
    """
    # t-SNE needs perplexity < number of samples; with a handful of tweets the
    # default of 10 would be rejected, so clamp it just below the sample count.
    n_samples = len(embeddings)
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    # NOTE(review): newer scikit-learn releases appear to rename ``n_iter`` to
    # ``max_iter`` — confirm against the pinned scikit-learn version.
    tsne = TSNE(
        perplexity=effective_perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)
def _pick_palette(n_colors: int) -> list:
    """Return ``n_colors`` colours, preferring ``Pallete`` and then ``AuxPallete``."""
    for family in (Pallete, AuxPallete):
        if n_colors in family:
            return list(family[n_colors])
    smallest = min(Pallete)
    if n_colors < smallest:
        # Fewer colours than the smallest predefined palette: take a prefix of it.
        return list(Pallete[smallest][:n_colors])
    # More colours than any predefined palette: cycle through the largest one.
    largest = AuxPallete[max(AuxPallete)]
    return [largest[i % len(largest)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot with one circle per text, coloured by ``values``.

    Args:
        texts: Text shown in the hover tooltip of each point.
        xs: X coordinate of each point.
        ys: Y coordinate of each point.
        values: Integer-like group id of each point; determines its colour.
        labels: Human-readable label of each point; shown in the tooltip and
            used to group the legend entries.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    values_list = values.astype(str).tolist()
    # One factor per distinct id, in numeric order. Colours are assigned by
    # rank so the palette is never indexed out of range, whether or not a
    # -1 id is present.
    factors = sorted(set(values_list), key=int)
    palette = _pick_palette(len(factors))

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels))
    # NOTE(review): "{safe}" appears to disable HTML escaping of the tooltip
    # text; the texts come from an external source — confirm this is intended.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    colors = factor_cmap("label", palette=palette, factors=factors)
    # The legend is added before the glyphs so it sits above the plot area.
    p.add_layout(Legend(location='top_left', title='Topics keywords', background_fill_alpha=0.2), 'above')
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="original_label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    return p
# End of the code adapted from the embedding-lenses Space.
def _cluster_keywords(cluster_tws, cluster_embeddings, model, top_k=3) -> str:
    """Describe a cluster with the words most similar to its mean embedding.

    Returns up to ``top_k`` words joined by ", ", or "-" when the cluster's
    cleaned tweets contain no words at all.
    """
    # Skip the empty tokens produced by consecutive spaces so they can never be
    # picked as keywords; sort so ties are broken the same way on every run.
    words = sorted({word for tw in cluster_tws for word in tw.split(' ') if word})
    if not words:
        return '-'
    centroid = np.mean(cluster_embeddings, axis=0)
    word_embeddings = embed_text(words, model)
    similarities = util.dot_score(centroid, word_embeddings)[0].tolist()
    ranked = sorted(range(len(words)), key=lambda word_ix: similarities[word_ix], reverse=True)
    return ', '.join(words[word_ix] for word_ix in ranked[:top_k])


def generate_plot(
    tws: List[str],
    tws_cleaned: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster and plot the tweets of ``tw_user``.

    The cleaned tweets are embedded with ``model`` and clustered with HDBSCAN.
    Every cluster is labelled with up to three keywords (the label -1 gets
    'Too diverse!'), the embeddings are projected with t-SNE, and the result is
    drawn as an interactive scatter plot.

    Args:
        tws: Original tweet texts, shown in the plot tooltips.
        tws_cleaned: Pre-processed texts, aligned index-by-index with ``tws``.
        model: Sentence embedding model.
        tw_user: Twitter handle, used only in the progress message.

    Returns:
        The Bokeh figure produced by ``draw_interactive_scatter_plot``.
    """
    # NOTE(review): the trailing emoji in both spinner messages look mis-encoded
    # in this copy of the file — confirm against the original source.
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... π€"):
        embeddings = embed_text(tws_cleaned, model)
        cluster = hdbscan.HDBSCAN(
            min_cluster_size=3,
            metric='euclidean',
            cluster_selection_method='eom'
        ).fit(embeddings)
        encoded_labels = cluster.labels_

    with st.spinner("Now trying to express them with my own words... π¬"):
        # Group tweet indices by cluster label in a single pass.
        members = {}
        for ix, label in enumerate(encoded_labels):
            members.setdefault(label, []).append(ix)

        cluster_keyword = {}
        for label, ixs in members.items():
            if label == -1:
                cluster_keyword[label] = 'Too diverse!'
                continue
            cluster_keyword[label] = _cluster_keywords(
                [tws_cleaned[ix] for ix in ixs],
                [embeddings[ix] for ix in ixs],
                model,
            )

        encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
        embeddings_2d = get_tsne_embeddings(embeddings)
        plot = draw_interactive_scatter_plot(
            tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
        )
    return plot
# ---- Streamlit page: inputs, model loading, and the fetch -> clean -> plot flow ----
# NOTE(review): the emoji at the end of several UI strings below look
# mis-encoded in this copy of the file — confirm against the original source.
st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')

col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

col1, col2 = st.columns(2)
with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize π')

with st.spinner(text="Loading brain... π§ "):
    try:
        model = load_model(model_to_use[expected_lang])
    except FileNotFoundError:
        # Second attempt with the same model name, constructing the model directly.
        model = SentenceTransformer(model_to_use[expected_lang])

# Strip spaces before the emptiness check so a whitespace-only handle is
# reported as empty instead of being sent to the API.
tw_user = tw_user.replace(' ', '')

if go_btn and tw_user != '':
    usr = client.get_user(username=tw_user)
    if usr.data is None:
        st.error(f"Could not find the Twitter user '{tw_user}'.")
        st.stop()

    with st.spinner(f"Getting to know the '{tw_user}'... π"):
        wanted = int(tw_sample)
        # The endpoint accepts page sizes between 5 and 100. The paginator
        # follows the pagination tokens so each request returns new tweets
        # rather than the first page again, and flatten() caps the total.
        page_size = min(100, max(5, wanted))
        tweets_objs = list(
            tweepy.Paginator(
                client.get_users_tweets,
                usr.data.id,
                max_results=page_size,
                exclude=['retweets', 'replies'],
            ).flatten(limit=wanted)
        )

    # De-duplicate while keeping the order in which the tweets were returned.
    tweets_txt = list(dict.fromkeys(tweet.text for tweet in tweets_objs))
    if not tweets_txt:
        st.warning(f"No tweets found for '{tw_user}'.")
        st.stop()

    tweets_txt_cleaned = preprocess(tweets_txt)
    plot = generate_plot(tweets_txt, tweets_txt_cleaned, model, tw_user)
    st.bokeh_chart(plot)
elif go_btn and tw_user == '':
    st.warning('Twitter handler field is empty π')
