# tweet-snest / app.py
# Last commit ea87688 by wilmerags: "fix: Improve null keywords resulting for topics"
from typing import List
import itertools
import string
import re
import requests
import tweepy
import hdbscan
import numpy as np
import streamlit as st
from gensim.utils import deaccent
from bokeh.models import ColumnDataSource, HoverTool, Label, Legend
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer, util
# Twitter API v2 client; the bearer token is read from Streamlit's secrets store.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Label shown in the language radio button -> sentence-transformers model name.
model_to_use = {
    "English": "all-MiniLM-L6-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2",
}

STOPWORDS_URL = "https://gist.githubusercontent.com/rg089/35e00abf8941d72d419224cfd5b5925d/raw/12d899b70156fd0041fa9778d657330b024b959c/stopwords.txt"
# A timeout keeps the app from hanging forever on a stalled download, and
# raise_for_status() stops an HTTP error page from being split into "stopwords".
_stopwords_response = requests.get(STOPWORDS_URL, timeout=30)
_stopwords_response.raise_for_status()
stopwords_list = _stopwords_response.content
# One stopword per line of the downloaded file.
stopwords = set(stopwords_list.decode().splitlines())
def _remove_unk_chars(txt_list: List[str]):
    """Normalize each tweet's characters and return a new list.

    For each tweet, in order: collapse every whitespace run to a single space,
    drop apostrophes, strip accents with gensim's ``deaccent``, then lowercase.
    """
    # r'\s+' must be a raw string: '\s' in a normal literal is an invalid
    # escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    return [
        deaccent(re.sub(r'\s+', ' ', tweet).replace("'", "")).lower()
        for tweet in txt_list
    ]
def _remove_urls(txt_list: List[str]):
url_regex = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'localhost|' #localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
txt_list = [tweet.split(' ') for tweet in txt_list]
return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list]
def _remove_punctuation(txt_list: List[str]):
punctuation = string.punctuation + 'ΒΏΒ‘|'
txt_list = [tweet.split(' ') for tweet in txt_list]
return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list]
def _remove_stopwords(txt_list: List[str]):
    """Remove every space-separated token found in the module-level ``stopwords`` set."""
    result = []
    for tweet in txt_list:
        survivors = (token for token in tweet.split(' ') if token not in stopwords)
        result.append(' '.join(survivors))
    return result
# Cleaning steps, applied in this order by ``preprocess``. URLs are dropped before
# punctuation removal, which would otherwise strip ':' '/' '.' and make them unrecognizable.
preprocess_pipeline = [_remove_unk_chars, _remove_urls, _remove_punctuation, _remove_stopwords]
def preprocess(txt_list: List[str]) -> List[str]:
    """Run ``txt_list`` through every step of ``preprocess_pipeline``, in order.

    Each step takes a list of tweet strings and returns a new list of the same
    length, so the result stays index-aligned with the input.
    """
    # The old annotation ``txt_list: str`` was wrong: every step iterates a list of tweets.
    for step in preprocess_pipeline:
        txt_list = step(txt_list)
    return txt_list
# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
# Default ``random_state`` for get_tsne_embeddings, so the 2D layout is repeatable across runs.
SEED: int = 42
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-transformers model named ``model_name``.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded model;
    ``allow_output_mutation=True`` stops Streamlit from checking the returned
    model for mutation between reruns.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string of ``text`` with ``model`` and return the embeddings."""
    encoded = model.encode(text)
    return encoded
def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 10, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with t-SNE.

    scikit-learn rejects a perplexity that is not smaller than the number of
    samples. For small inputs (e.g. a user with only a few tweets), ``perplexity``
    is therefore lowered to ``n_samples - 1``. Inputs with more than
    ``perplexity`` samples use it unchanged.

    Returns:
        Array with one row per input embedding and ``n_components`` columns.
    """
    n_samples = len(embeddings)
    if n_samples > 1:
        perplexity = min(perplexity, n_samples - 1)
    # NOTE(review): newer scikit-learn renames ``n_iter`` to ``max_iter``; check the pinned version.
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)
def _topic_palette(n_colors: int) -> List[str]:
    """Return exactly ``n_colors`` colors for categorical topic labels.

    Prefers the Colorblind palette, then Set3, when either has an entry of that
    exact size. Fewer colors than the smallest entry are truncated from Colorblind.
    More colors than Set3's largest entry are produced by cycling it, so colors
    repeat but nothing crashes.
    """
    if n_colors <= 0:
        return []
    for family in (Pallete, AuxPallete):
        if n_colors in family:
            return list(family[n_colors])
    if n_colors < min(Pallete):
        return list(Pallete[min(Pallete)])[:n_colors]
    largest = AuxPallete[max(AuxPallete)]
    return list(itertools.islice(itertools.cycle(largest), n_colors))


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot with one point per text, colored by topic.

    Args:
        texts: hover text for each point.
        xs, ys: point coordinates.
        values: numeric topic id per point; picks the point's color.
        labels: human-readable topic label per point; shown in tooltip and legend.
        text_column, label_column: captions of the two tooltip rows.

    Returns:
        The configured Bokeh figure.
    """
    values = np.asarray(values)
    values_list = values.astype(str).tolist()
    # One factor per distinct topic id, in numeric order (-1, 0, 1, ..., 10).
    # They are stringified exactly like the per-point "label" column so the
    # color mapper finds them. The palette has exactly one color per factor.
    # The old code indexed Pallete[n][id + 1] (out of range when no -1 label
    # exists) and its AuxPallete fallback never ran because of a typo.
    factors = np.unique(values).astype(str).tolist()
    palette = _topic_palette(len(factors))
    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels))
    # NOTE(review): "{safe}" inserts the text into the tooltip as raw HTML. This
    # presumably relies on the Twitter API HTML-escaping tweet text (it renders
    # entities like &amp;); do not feed it other untrusted text.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    colors = factor_cmap("label", palette=palette, factors=factors)
    p.add_layout(Legend(location='top_left', title='Topics keywords', background_fill_alpha=0.2), 'above')
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="original_label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    return p
# End of the code adapted from edugp/embedding-lenses
def _describe_cluster(
    cluster_tws: List[str],
    cluster_embeddings: List[np.ndarray],
    model: SentenceTransformer,
    max_keywords: int = 3,
) -> str:
    """Return up to ``max_keywords`` comma-separated words summarizing one cluster.

    Candidates are the distinct non-empty tokens of the cluster's cleaned tweets.
    They are ranked by dot-product similarity between each word's embedding and
    the mean embedding of the cluster's tweets. Returns '-' when there are no
    candidate words.
    """
    # Empty tokens come from punctuation-only words (emptied by punctuation
    # removal) and leading/trailing spaces; they produced blank topic keywords.
    # Sorting makes candidate order, and therefore tie-breaking, deterministic.
    candidate_words = sorted({word for tw in cluster_tws for word in tw.split(' ') if word})
    if not candidate_words:
        return '-'
    centroid = np.mean(cluster_embeddings, axis=0)
    words_embeddings = embed_text(candidate_words, model)
    similarities = util.dot_score(centroid, words_embeddings)[0].tolist()
    ranked = sorted(range(len(candidate_words)), key=lambda ix: similarities[ix], reverse=True)
    return ', '.join(candidate_words[ix] for ix in ranked[:max_keywords])


def generate_plot(
    tws: List[str],
    tws_cleaned: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster and project tweets, returning an interactive scatter plot.

    Args:
        tws: original tweet texts, shown in the hover tooltip.
        tws_cleaned: preprocessed version of ``tws`` (same order); used for the
            embeddings and for keyword extraction.
        model: sentence-transformers model used for tweets and candidate words.
        tw_user: Twitter handle, used only in the progress message.

    Returns:
        Figure with one point per tweet, colored by HDBSCAN cluster and labelled
        with the cluster's keywords. Unclustered tweets (label -1) are labelled
        'Too diverse!'.
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... 🤔"):
        embeddings = embed_text(tws_cleaned, model)
        cluster = hdbscan.HDBSCAN(
            min_cluster_size=3,
            metric='euclidean',
            cluster_selection_method='eom'
        ).fit(embeddings)
    encoded_labels = cluster.labels_
    cluster_keyword = {}
    with st.spinner("Now trying to express them with my own words... 💬"):
        for label in set(encoded_labels):
            if label == -1:
                # -1 marks tweets HDBSCAN did not assign to any cluster.
                cluster_keyword[label] = 'Too diverse!'
                continue
            member_ixs = np.flatnonzero(encoded_labels == label)
            cluster_keyword[label] = _describe_cluster(
                [tws_cleaned[ix] for ix in member_ixs],
                [embeddings[ix] for ix in member_ixs],
                model,
            )
    encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
    embeddings_2d = get_tsne_embeddings(embeddings)
    plot = draw_interactive_scatter_plot(
        tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
    )
    return plot
st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')

# First row: whose tweets to fetch and how many.
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    # Positional arguments: min_value=1, max_value=300, value=100, step=10.
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

# Second row: embedding-model language and the action button.
col1, col2 = st.columns(2)
with col1:
    # These option labels are the keys of ``model_to_use``; index 0 preselects English.
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    # The emoji was previously UTF-8 mojibake ('πŸš€').
    go_btn = st.button('Visualize 🚀')
with st.spinner(text="Loading brain... 🧠"):
    selected_model_name = model_to_use[expected_lang]
    try:
        model = load_model(selected_model_name)
    except FileNotFoundError:
        # NOTE(review): presumably a workaround for a failure inside the cached
        # loader; falls back to an uncached load of the same model. Confirm.
        model = SentenceTransformer(selected_model_name)
if go_btn:
    # Handles are typed by hand: drop stray spaces and a leading '@' before the
    # emptiness check, so whitespace-only input gets the warning instead of an API error.
    tw_user = tw_user.replace(' ', '').lstrip('@')
    if not tw_user:
        st.warning('Twitter handler field is empty 🙄')
    else:
        usr = client.get_user(username=tw_user)
        if usr.data is None:
            st.error(f"Couldn't find the Twitter user '{tw_user}'")
        else:
            with st.spinner(f"Getting to know the '{tw_user}'... 🔍"):
                # The Paginator follows the API's next_token, so each request fetches a
                # new page. The old loop re-sent the same request and kept re-reading the
                # newest 100 tweets. The endpoint only accepts max_results in [5, 100];
                # flatten() stops after tw_sample tweets.
                tweets_objs = list(
                    tweepy.Paginator(
                        client.get_users_tweets,
                        usr.data.id,
                        max_results=min(100, max(5, tw_sample)),
                        exclude=['retweets', 'replies'],
                    ).flatten(limit=tw_sample)
                )
            # Deduplicate while keeping API order, so identical inputs give identical plots.
            tweets_txt = list(dict.fromkeys(tweet.text for tweet in tweets_objs))
            if not tweets_txt:
                st.warning(f"No original tweets found for '{tw_user}'")
            else:
                tweets_txt_cleaned = preprocess(tweets_txt)
                plot = generate_plot(tweets_txt, tweets_txt_cleaned, model, tw_user)
                st.bokeh_chart(plot)