Twitter-keyword-analysis / inference.py
EmanuelRiquelme's picture
Upload 5 files
3efdb8d
raw
history blame contribute delete
No virus
5.13 kB
from keybert import KeyBERT
from sen_model import Sentiment
from sampling import sampling_inference
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mplcyberpunk
from adjustText import adjust_text
class Keyword_oracle():
def __init__(self,file_name,
weight_rt_fav = [1,4],
noise_threshold = 75,
words_exp = ["user","http","rt","fav",'https'],
**kwargs
):
self.key_bert = KeyBERT()
self.file_name = file_name
self.keybert_args = kwargs
self.weight_rt_fav = weight_rt_fav
self.raw_tweets= sampling_inference(file_name).sampled_df()
self.noise_threshold = noise_threshold if kwargs['top_n'] == 1 else 90 if kwargs['top_n'] == 2 else 95
self.tweets = self.raw_tweets['Tweet']
self.retweet = self.raw_tweets['Retweet']
self.favs = self.raw_tweets['Favs']
self.sentiment_eval = self.__sentimient_eval__()
self.words_exp = words_exp
self.mined_tweets = self.__tweets_mined__()
self.denoised_df = self.__denoised_df__()
self.percentiles = self.__find_threshold__()
self.categorical = self.__categorical__()
def __sentimient_eval__(self):
return Sentiment(self.tweets)
def __tweets_mined__(self):
raw_keywords = self.key_bert.extract_keywords(self.tweets,
keyphrase_ngram_range = self.keybert_args['keyphrase_ngram_range'],
diversity = self.keybert_args['diversity'],
top_n = self.keybert_args['top_n']
)
key_words,engagement,acum_sents = [],[],[]
for keys,retweet,fav,sent in zip(raw_keywords,self.retweet,self.favs,self.sentiment_eval):
for key in keys:
if not set(key[0].split()).intersection(set(self.words_exp)):
key_words.append(key[0])
engagement.append(1+retweet/self.weight_rt_fav[0]+fav/self.weight_rt_fav[1])
acum_sents.append(sent+retweet/self.weight_rt_fav[0]*(sent)+fav/self.weight_rt_fav[1]*sent)
key_word_data = {
"Key": key_words,
'engagement': engagement,
'emotions overall':acum_sents
}
return pd.DataFrame(key_word_data).groupby(['Key'], as_index=False).sum()
def __denoised_df__(self):
df = self.mined_tweets
tweets = df['engagement']
percentile = np.percentile(tweets, self.noise_threshold)
return df[tweets > percentile].reset_index(drop=True)
def __find_threshold__(self):
df = self.mined_tweets
tweets = df['emotions overall']
top_threshold = self.noise_threshold
bottom_threshold = 100-top_threshold
while np.percentile(tweets,top_threshold) <= 0 and np.percentile(tweets,100-top_threshold):
try:
top_threshold +=5
bottom_threshold -= 5
except top_threshold == 95:
top_threshold,bottom_threshold = 0,0
bottom_threshold,top_threshold = np.percentile(tweets,bottom_threshold),np.percentile(tweets,top_threshold)
return bottom_threshold,top_threshold
def __categorical__(self):
df = self.denoised_df
tweets = df['emotions overall'].to_numpy()
categorical = ['neutral','positive','negative']
bottom_threshold,top_threshold = self.percentiles
pos = (tweets >= top_threshold) if top_threshold > 0 else np.zeros(tweets.shape[0])
neg = (tweets <= bottom_threshold)*-1 if bottom_threshold < 0 else np.zeros(tweets.shape[0])
numerical = pos+neg
return [categorical[index] for index in numerical.astype(int)]
def return_table(self):
self.denoised_df['Categorical'] = self.__categorical__()
return self.denoised_df.sort_values(by=['emotions overall'],ascending = False).reset_index(drop=True)
def plot(self):
df = self.denoised_df
plt.style.use("cyberpunk")
keys = df['Key']
x,y = df['engagement'],df['emotions overall']
fig, ax = plt.subplots()
ax.scatter(x, y)
text = [plt.text(x_value,y_value,key_value) for x_value,y_value,key_value in zip(x,y,keys)]
adjust_text(text)
bottom_threshold,top_threshold = self.percentiles
plt.axhline(bottom_threshold ,c= "red", marker='.', linestyle=':') if bottom_threshold < 0 else None
plt.axhline(top_threshold,c= "magenta", marker='.', linestyle=':') if top_threshold > 0 else None
plt.title(f"Denoised sentiment analysis of {self.file_name}")
plt.xlabel("Engagement")
plt.ylabel("Emotions Overall")
return fig
if __name__ == "__main__":
file_name ='Graham Potter'
Keyword_oracle = Keyword_oracle(file_name,
keyphrase_ngram_range = (1,2),
diversity=0.3,top_n=3)
Keyword_oracle.plot()
print(Keyword_oracle.return_table())