Upload model.py
Browse files
model.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Importing necessary libraries.
|
2 |
+
import streamlit as st
|
3 |
+
st.set_page_config(page_title="Monkeypox misinformation detector",
|
4 |
+
page_icon=":lion:",
|
5 |
+
layout="wide",
|
6 |
+
initial_sidebar_state="auto",
|
7 |
+
menu_items=None)
|
8 |
+
import tweepy as tw
|
9 |
+
import textacy
|
10 |
+
from textacy import preprocessing
|
11 |
+
import emoji
|
12 |
+
import pandas as pd
|
13 |
+
import numpy as np
|
14 |
+
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
|
15 |
+
import tensorflow as tf
|
16 |
+
import datetime as dt
|
17 |
+
import time
|
18 |
+
import copy
|
19 |
+
import altair as alt
|
20 |
+
|
21 |
+
|
22 |
+
@st.experimental_singleton(show_spinner=False)
|
23 |
+
def load_model():
|
24 |
+
"""
|
25 |
+
This function loads the fine-tuned HuggingFace model and caches
|
26 |
+
it (using the experimental_singleton decorator) to improve
|
27 |
+
computation times.
|
28 |
+
|
29 |
+
Parameters: none.
|
30 |
+
Returns: HuggingFace transformer model.
|
31 |
+
"""
|
32 |
+
|
33 |
+
model = TFAutoModelForSequenceClassification.from_pretrained("smcrone/monkeypox-misinformation")
|
34 |
+
model.compile(
|
35 |
+
optimizer=tf.keras.optimizers.Adam(learning_rate=5e-6),
|
36 |
+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
37 |
+
metrics=tf.keras.metrics.SparseCategoricalAccuracy())
|
38 |
+
return model
|
39 |
+
|
40 |
+
|
41 |
+
@st.experimental_singleton(show_spinner=False)
|
42 |
+
def load_tokenizer():
|
43 |
+
"""
|
44 |
+
This function loads a tokenizer for the transformer model and caches
|
45 |
+
it (using the experimental_singleton decorator) to improve
|
46 |
+
computation times.
|
47 |
+
|
48 |
+
Parameters: none.
|
49 |
+
Returns: tokenizer.
|
50 |
+
"""
|
51 |
+
|
52 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased",use_fast=False)
|
53 |
+
return tokenizer
|
54 |
+
|
55 |
+
|
56 |
+
@st.experimental_singleton(show_spinner=False)
|
57 |
+
def load_client():
|
58 |
+
"""
|
59 |
+
This function authenticates the Tweepy client and caches
|
60 |
+
the object (using the experimental_singleton decorator) to
|
61 |
+
improve computation times.
|
62 |
+
|
63 |
+
Parameters: none.
|
64 |
+
Returns: Tweepy client.
|
65 |
+
"""
|
66 |
+
|
67 |
+
bearer_token = st.secrets["bearer_token"]
|
68 |
+
client = tw.Client(bearer_token,wait_on_rate_limit=True)
|
69 |
+
return client
|
70 |
+
|
71 |
+
|
72 |
+
def dataframe_preprocessing(df_to_preprocess:pd.DataFrame):
|
73 |
+
"""
|
74 |
+
The program overall collects tweet data at two junctures: firstly
|
75 |
+
on provision of the initial tweet, and secondly if the classification
|
76 |
+
of the initial tweet prompts a review of the user's other recent tweets.
|
77 |
+
At both of these junctures certain preprocessing steps -- designed to
|
78 |
+
increase the intelligibility of text inputs to the model -- are identical,
|
79 |
+
so this function is designed to avoid the unnecessary repetition of this
|
80 |
+
code. The function takes a Pandas DataFrame for preprocessing and returns
|
81 |
+
the DataFrame, having executed certain preprocessing steps (e.g. removal
|
82 |
+
of emojis, normalization of whitespace, removal of columns, etc.)
|
83 |
+
|
84 |
+
Parameters: df_to_preprocess (DataFrame)
|
85 |
+
Returns: df_to_preprocess (DataFrame)
|
86 |
+
"""
|
87 |
+
|
88 |
+
# userlocation will not be in dataframe is user not supplied field. So, for time being, fill with none if it does not exist.
|
89 |
+
# !!! note: we will likely NOT use userlocation, so can remove this bit of code in later versions!!!
|
90 |
+
if 'userlocation' not in df_to_preprocess.columns:
|
91 |
+
df_to_preprocess['userlocation'] = 'None'
|
92 |
+
# Dropping redundant columns.
|
93 |
+
df_to_preprocess = df_to_preprocess.drop(labels=['public_metrics', 'userpublic_metrics'], axis=1)
|
94 |
+
# Stripping timezone info for export to Excel.
|
95 |
+
df_to_preprocess['created_at'] = df_to_preprocess['created_at'].dt.tz_localize(None)
|
96 |
+
df_to_preprocess['usercreated_at'] = df_to_preprocess['usercreated_at'].dt.tz_localize(None)
|
97 |
+
# Replacing URLs and emojis; normalizing bullet points, whitespace, etc.
|
98 |
+
for feature in ['text','userdescription','userlocation','userurl','username']:
|
99 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].fillna('None').apply(str)
|
100 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].apply(lambda x: textacy.preprocessing.replace.urls(text= x, repl= '_URL_'))
|
101 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].apply(lambda x: emoji.demojize(x))
|
102 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].apply(lambda x: textacy.preprocessing.normalize.bullet_points(text=x))
|
103 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].apply(lambda x: textacy.preprocessing.normalize.quotation_marks(text=x))
|
104 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].apply(lambda x: textacy.preprocessing.normalize.whitespace(text=x))
|
105 |
+
df_to_preprocess[feature] = df_to_preprocess[feature].replace('\n', ' ', regex=True).replace('\r', '', regex=True)
|
106 |
+
# Renaming columns (for greater model intelligibility).
|
107 |
+
df_to_preprocess.rename(columns={"userverified": "user is verified",
|
108 |
+
"userurl": "user has url",
|
109 |
+
"userdescription": "user description",
|
110 |
+
"usercreated_at": "user created at",
|
111 |
+
"followers_count": "followers count",
|
112 |
+
"following_count": "following count",
|
113 |
+
"tweet_count": "tweet count",
|
114 |
+
"userlocation": "user location"},
|
115 |
+
inplace=True)
|
116 |
+
# Making URL column binary.
|
117 |
+
df_to_preprocess['user has url'].replace({'_URL_': 'True', "": 'False'}, inplace=True)
|
118 |
+
# Adding some extra features.
|
119 |
+
df_to_preprocess['years since account created'] = df_to_preprocess['created_at'].dt.year.astype('Int64') - df_to_preprocess['user created at'].dt.year.astype('Int64')
|
120 |
+
df_to_preprocess['tweets per day'] = df_to_preprocess['tweet count']/((df_to_preprocess['created_at'] - df_to_preprocess['user created at']).dt.days)
|
121 |
+
df_to_preprocess['follower to following ratio'] = df_to_preprocess['followers count']/(df_to_preprocess['following count']+1)
|
122 |
+
# Returning processed DataFrame.
|
123 |
+
return df_to_preprocess
|
124 |
+
|
125 |
+
|
126 |
+
def feature_concatenation(dataframe_to_concatenate:pd.DataFrame,features:list):
|
127 |
+
"""
|
128 |
+
Our transformer model was fine-tuned on text input that combines
|
129 |
+
a number of fields in a single string. This function performs
|
130 |
+
the concatenation of these features, which in addition to dataframe
|
131 |
+
preprocessing, is a necessary preprocessing step. The final dataframe
|
132 |
+
consists of just two columns: one containing the concatenated text and
|
133 |
+
the other containing the number of retweets that the tweet received
|
134 |
+
(for use later on).
|
135 |
+
|
136 |
+
Parameters:
|
137 |
+
|
138 |
+
1. dataframe_to_concatenate (DataFrame): the df from which to take the features.
|
139 |
+
2. features (list of str): the features to concatenate.
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
|
143 |
+
1. finalDataFrame (DataFrame): the dataframe to be passed to the model.
|
144 |
+
"""
|
145 |
+
|
146 |
+
# Make copy of dataframe consisting only of specified features.
|
147 |
+
concatenated_dataframe = dataframe_to_concatenate[features].copy()
|
148 |
+
# Concatenate chosen features.
|
149 |
+
for i in features:
|
150 |
+
concatenated_dataframe[i] = concatenated_dataframe[i].name + ": " + concatenated_dataframe[i].astype(str)
|
151 |
+
concatenated_dataframe['combined'] = concatenated_dataframe[features].apply(lambda row: ' [SEP] '.join(row.values.astype(str)), axis=1)
|
152 |
+
final_concatenated_dataframe = pd.DataFrame({"combined":concatenated_dataframe['combined'],"retweets":dataframe_to_concatenate['retweet_count']})
|
153 |
+
# Return the final DataFrame.
|
154 |
+
return final_concatenated_dataframe
|
155 |
+
|
156 |
+
|
157 |
+
def classify_tweets(dataframe_to_classify:pd.DataFrame):
|
158 |
+
"""
|
159 |
+
This function takes a DataFrame of tweets which, having gone through
|
160 |
+
the necessary preprocessing steps, is ready to classify. The function
|
161 |
+
is called both for the initial classification of a single tweet and,
|
162 |
+
where necessary, the superspreader analysis of the user's recent tweets.
|
163 |
+
The function iterates through the DataFrame provided, tokenizing and
|
164 |
+
classifying each tweet, and assigning it to one of two lists within a
|
165 |
+
dictionary: 'goodPosts' (i.e. non-misleading posts) and 'badPosts (i.e.
|
166 |
+
misleading posts). The function then returns the dictionary, which for
|
167 |
+
each post includes the tweet itself, the predicted class, the confidence
|
168 |
+
of the prediction, and the number of retweets received by the post.
|
169 |
+
|
170 |
+
Parameters: dataframe_to_classify (DataFrame) -- the preprocessed
|
171 |
+
DataFrame of tweet(s).
|
172 |
+
|
173 |
+
Returns: tweet_dict (dict): a dictionary of classification results.
|
174 |
+
"""
|
175 |
+
|
176 |
+
# Storing classification results in a dictionary with two keys.
|
177 |
+
tweet_dict ={}
|
178 |
+
tweet_dict['goodPosts'] = []
|
179 |
+
tweet_dict['badPosts'] = []
|
180 |
+
# Iterate through each tweet string in the DataFrame provided.
|
181 |
+
for i in range(len(dataframe_to_classify['combined'])):
|
182 |
+
# First, tokenize the tweet.
|
183 |
+
tokenized_tweet = tokenizer(dataframe_to_classify['combined'].iloc[i],padding="max_length",truncation=True)
|
184 |
+
# Next, convert tweet to a format that TensorFlow will accept.
|
185 |
+
predict_dict = {}
|
186 |
+
for x,y in tokenized_tweet.items():
|
187 |
+
a = tf.convert_to_tensor(y, dtype=None, dtype_hint=None, name=None)
|
188 |
+
b = tf.reshape(a,[1,512])
|
189 |
+
predict_dict[x] = b
|
190 |
+
# Call model to predict tweet.
|
191 |
+
prediction = model(predict_dict,training=False)
|
192 |
+
# Take pred. class and confidence in pred. class
|
193 |
+
pred_class = np.argmax(np.array(tf.nn.softmax(prediction.logits)))
|
194 |
+
pred_conf = np.max(np.array(tf.nn.softmax(prediction.logits)))
|
195 |
+
# Construct a list of variables that we wish to store.
|
196 |
+
seq_to_append = [dataframe_to_classify['combined'].iloc[i],pred_class,pred_conf,dataframe_to_classify['retweets'].iloc[i]]
|
197 |
+
# Add list under appropriate dictionary key.
|
198 |
+
if pred_class == 1:
|
199 |
+
tweet_dict['badPosts'].append(seq_to_append)
|
200 |
+
elif pred_class == 0:
|
201 |
+
tweet_dict['goodPosts'].append(seq_to_append)
|
202 |
+
else:
|
203 |
+
print("Something went wrong.")
|
204 |
+
return
|
205 |
+
# Return the dictionary of results.
|
206 |
+
return tweet_dict
|
207 |
+
|
208 |
+
|
209 |
+
def get_user_tweets(user_id:str, days_to_go_back:int, client:tw.Client):
|
210 |
+
"""
|
211 |
+
If the initial tweet provided to the web app is classified as
|
212 |
+
misleading, then relevant tweets from the user must be gathered
|
213 |
+
in order to perform the superspreader calculation. This function
|
214 |
+
supports this process by collecting relevant user tweets, undertaking
|
215 |
+
the necessary preprocessing steps (with support from other functions),
|
216 |
+
and classifying the tweets using the classify_tweets function. It
|
217 |
+
then returns the dictionary of results produced by classify_tweets.
|
218 |
+
|
219 |
+
Parameters:
|
220 |
+
|
221 |
+
1. user_id (int|str): the user_id to be fed to Tweepy.
|
222 |
+
2. days_to_go_back (int): how many days' tweets to investigate.
|
223 |
+
3. client: the Tweepy client instantiated by load_client.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
|
227 |
+
1. user_tweets_classified (dict): model outputs for user tweets.
|
228 |
+
"""
|
229 |
+
|
230 |
+
# STAGE 1. FETCH USER TWEETS
|
231 |
+
|
232 |
+
# Converting days_to_go_back into variables that can be fed to Tweepy.
|
233 |
+
d = dt.datetime.today() - dt.timedelta(days=days_to_go_back)
|
234 |
+
year = str(d.year)
|
235 |
+
month = str(d.month)
|
236 |
+
if len(month) == 1:
|
237 |
+
month = '0'+month
|
238 |
+
day = str(d.day)
|
239 |
+
if len(day) == 1:
|
240 |
+
day = '0'+day
|
241 |
+
hour = str(d.hour)
|
242 |
+
if len(hour) == 1:
|
243 |
+
hour = '0'+hour
|
244 |
+
# Gathering tweets from user.
|
245 |
+
try:
|
246 |
+
tweets_we_want_to_check = tw.Paginator(client.get_users_tweets,
|
247 |
+
id = user_id,
|
248 |
+
end_time=None,
|
249 |
+
exclude=None,
|
250 |
+
expansions=['author_id'],
|
251 |
+
max_results=100,
|
252 |
+
media_fields=None,
|
253 |
+
pagination_token=None,
|
254 |
+
place_fields=None,
|
255 |
+
poll_fields=None,
|
256 |
+
since_id=None,
|
257 |
+
start_time='{}-{}-{}T{}:00:00Z'.format(year,month,day,hour),
|
258 |
+
tweet_fields=['author_id','created_at','public_metrics','source'],
|
259 |
+
until_id=None,
|
260 |
+
user_fields=['created_at','description','location','public_metrics','url','verified'],
|
261 |
+
user_auth=False,
|
262 |
+
limit=500)
|
263 |
+
except:
|
264 |
+
return "Something went wrong whilst performing superspreader analysis."
|
265 |
+
|
266 |
+
# STAGE 2. PREPROCESSING TWEET DATA
|
267 |
+
|
268 |
+
# Parsing response data into an intermediate form.
|
269 |
+
tweet_data_for_user = []
|
270 |
+
user_data_for_user = []
|
271 |
+
for page in tweets_we_want_to_check:
|
272 |
+
# Converting each set of tweet fields into a dict and appending to list.
|
273 |
+
for tweet in page.data:
|
274 |
+
result = dict(tweet)
|
275 |
+
tweet_data_for_user.append(result)
|
276 |
+
# Converting each set of user fields into a dict and appending to list.
|
277 |
+
for user in page.includes['users']:
|
278 |
+
result = dict(user)
|
279 |
+
user_data_for_user.append(result)
|
280 |
+
# Adding user fields to tweet fields.
|
281 |
+
for tweet in tweet_data_for_user:
|
282 |
+
for user in user_data_for_user:
|
283 |
+
for key, val in user.items():
|
284 |
+
newKey = "user"+key
|
285 |
+
tweet[newKey] = val
|
286 |
+
break
|
287 |
+
# Unpack and append any values that are dictionaries.
|
288 |
+
for tweet in tweet_data_for_user:
|
289 |
+
additional_values = {}
|
290 |
+
for key, val in tweet.items():
|
291 |
+
if type(val) == dict:
|
292 |
+
for subkey, subval in val.items():
|
293 |
+
additional_values[subkey] = subval
|
294 |
+
tweet.update(additional_values)
|
295 |
+
# Create a Pandas DataFrame to store the data.
|
296 |
+
user_df = pd.DataFrame(tweet_data_for_user)
|
297 |
+
# Perform additional preprocessing using dedicated function.
|
298 |
+
user_df = dataframe_preprocessing(user_df)
|
299 |
+
# Drop non-monkeypox related rows.
|
300 |
+
user_df['monkeypox'] = user_df['text'].str.contains('monkeypox|monkey pox|money pox', case=False, regex=True)
|
301 |
+
user_df.drop(user_df[user_df.monkeypox == False].index, inplace=True)
|
302 |
+
# Concatenating chosen features.
|
303 |
+
concatenated_df = feature_concatenation(user_df,['text'])
|
304 |
+
|
305 |
+
# STAGE 3. CALLING CLASSIFIER AND RETURNING RESULTS
|
306 |
+
|
307 |
+
# Calling classifier.
|
308 |
+
classified_tweets = classify_tweets(concatenated_df)
|
309 |
+
# Returning dictionary of classified tweets.
|
310 |
+
return classified_tweets
|
311 |
+
|
312 |
+
|
313 |
+
def on_receipt_of_tweet_query(request:str,client:tw.Client):
|
314 |
+
"""
|
315 |
+
This function defines what the app should do on receipt of a tweet
|
316 |
+
URL / ID from the end-user. It performs the following steps:
|
317 |
+
(i) formats the string submitted by the userinto a parsable form;
|
318 |
+
(ii) fetches data for the tweet using Tweepy; (iii) performs some
|
319 |
+
basic preprocessing on the data; (iv) calls dedicated preprocessing
|
320 |
+
functions to finish preprocessing the data; (v) calls the classifier
|
321 |
+
on the tweet; (vi) determines whether superspreader analysis is
|
322 |
+
needed (i.e. if tweet is classed as misleading); (vii) if so,
|
323 |
+
calls get_user_tweet function and calculates a superspreader score;
|
324 |
+
(viii) returns a tuple of data for the application to display.
|
325 |
+
|
326 |
+
Parameters:
|
327 |
+
|
328 |
+
1. request (str): the URL or ID provided by the end-user.
|
329 |
+
2. client: the Tweepy client instantiated by load_client.
|
330 |
+
|
331 |
+
Returns:
|
332 |
+
|
333 |
+
1. classified_tweet (dict): the metrics returned for the tweet by classify_tweets.
|
334 |
+
2. spreader_score (float): where applicable, a metric representing the
|
335 |
+
3. extent to which the user can be regarded as a superspreader of misinformation.
|
336 |
+
4. tweet_text (str): the text of the tweet queried by the end-user.
|
337 |
+
5. followers_count (int): the number of followers that the user has.
|
338 |
+
6. classified_user_tweets (dict): where applicable, the metrics returned by
|
339 |
+
7. get_user_tweets.
|
340 |
+
"""
|
341 |
+
|
342 |
+
# STAGE 1. FETCH DATA FOR REQUESTED TWEET
|
343 |
+
|
344 |
+
# If URL is provided by the end-user, strip out the tweet ID.
|
345 |
+
if '/' in request:
|
346 |
+
request = request.split('/')[-1]
|
347 |
+
# Collect tweet data -- interrupt if invalid input provided.
|
348 |
+
tweet = client.get_tweets(ids=request,
|
349 |
+
expansions=['author_id'],
|
350 |
+
media_fields=None,
|
351 |
+
place_fields=None,
|
352 |
+
poll_fields=None,
|
353 |
+
tweet_fields=['author_id','created_at','public_metrics','source'],
|
354 |
+
user_fields=['created_at','description','location','public_metrics','url','verified'],
|
355 |
+
user_auth=False)
|
356 |
+
|
357 |
+
# STAGE 2. PREPROCESSING OF TWEET DATA
|
358 |
+
|
359 |
+
# Create dictionaries out of the tweet and user data.
|
360 |
+
for i in tweet.data:
|
361 |
+
tweet_fields = dict(i)
|
362 |
+
for i in tweet.includes['users']:
|
363 |
+
user_fields = dict(i)
|
364 |
+
# Add the data from the user dict to the tweet dict.
|
365 |
+
for key, val in user_fields.items():
|
366 |
+
newKey = "user"+key
|
367 |
+
tweet_fields[newKey] = val
|
368 |
+
# Unpack any values which are themselves dictionaries.
|
369 |
+
additional_values = {}
|
370 |
+
for key, val in tweet_fields.items():
|
371 |
+
if type(val) == dict:
|
372 |
+
for subkey, subval in val.items():
|
373 |
+
additional_values[subkey] = subval
|
374 |
+
tweet_fields.update(additional_values)
|
375 |
+
# Convert everything to a DataFrame.
|
376 |
+
tweet_df = pd.DataFrame(tweet_fields,index=[0])
|
377 |
+
# Store the raw tweet text itself for later use.
|
378 |
+
tweet_text = tweet_df['text'][0]
|
379 |
+
# Store the followers count for later use.
|
380 |
+
followers_count = tweet_df['followers_count'][0]
|
381 |
+
# Preprocess the data using dedicated functions.
|
382 |
+
tweet_df = dataframe_preprocessing(tweet_df)
|
383 |
+
concatenated_tweet_df = feature_concatenation(tweet_df,['text'])
|
384 |
+
|
385 |
+
# STAGE 3. CALLING CLASSIFIER AND DETERMINING NEXT STEPS
|
386 |
+
|
387 |
+
# Call the classifier on the tweet.
|
388 |
+
classified_tweet = classify_tweets(concatenated_tweet_df)
|
389 |
+
# If the tweet is misleading, call get_user_tweets and calculate
|
390 |
+
# the user's superspreader score.
|
391 |
+
if len(classified_tweet['badPosts']) == 1:
|
392 |
+
# Fetch a dictionary of classified user tweets
|
393 |
+
classified_user_tweets = get_user_tweets(tweet_df['userid'][0],14,client=client)
|
394 |
+
# Calculate the total number of retweets for all misleading posts.
|
395 |
+
retweets_total = 0
|
396 |
+
for tweet in classified_user_tweets['badPosts']:
|
397 |
+
retweets_total += tweet[-1]
|
398 |
+
# Assign the p (post) value.
|
399 |
+
p = (0.21 * len(classified_user_tweets['badPosts'])) ** 1.13
|
400 |
+
# Assign the f (follower) value
|
401 |
+
f = (0.25 * (np.log10(followers_count+1))) ** 4.73
|
402 |
+
# Assign the r (retweet) value
|
403 |
+
r = (1.04 * (np.log10(retweets_total+1))) ** 0.96
|
404 |
+
# Calculate spreader_score and return a tuple of info.
|
405 |
+
spreader_score = max(((1 - (1/(max(1,p+f+r))))*100),1)
|
406 |
+
return classified_tweet, tweet_text, followers_count, classified_user_tweets, retweets_total, spreader_score,
|
407 |
+
# Otherwise, if tweet is not misleading, return the same info
|
408 |
+
# (excluding any superspreader related variables).
|
409 |
+
elif len(classified_tweet['goodPosts']) == 1:
|
410 |
+
return classified_tweet, tweet_text, followers_count, 0, 0, 0
|
411 |
+
# Contingency in case an error should unexpectedly occur.
|
412 |
+
else:
|
413 |
+
raise Exception("Something went wrong whilst processing tweet data.")
|
414 |
+
|
415 |
+
|
416 |
+
def webpage():
|
417 |
+
"""
|
418 |
+
This function structures the main page of the web app using the
|
419 |
+
conventions of Streamlit. It begins by loading the model, the tokenizer
|
420 |
+
and the Tweepy client using the functions dedicated to those tasks.
|
421 |
+
Each of these elements is then cached. The remaining content that the
|
422 |
+
function generates then depends mostly on the inputs provided by the
|
423 |
+
end-user.
|
424 |
+
|
425 |
+
Parameters: none.
|
426 |
+
Returns: nothing.
|
427 |
+
"""
|
428 |
+
|
429 |
+
# Create a container for displaying loading messages which will clear
|
430 |
+
# once the tokenizer, Tweepy client and transformer model have loaded.
|
431 |
+
loading_container = st.empty()
|
432 |
+
with loading_container.container():
|
433 |
+
global model
|
434 |
+
model = load_model()
|
435 |
+
global client
|
436 |
+
client = load_client()
|
437 |
+
global tokenizer
|
438 |
+
tokenizer = load_tokenizer()
|
439 |
+
loading_container.empty()
|
440 |
+
|
441 |
+
# Write header content (e.g. banner image, title, description).
|
442 |
+
st.image("monkeypox-small.jpg")
|
443 |
+
st.title("Monkeypox misinformation detector")
|
444 |
+
st.write("Use this tool to detect whether a tweet contains\
|
445 |
+
monkeypox misinformation and assess the extent to which its\
|
446 |
+
poster can be considered a misinformation superspreader.")
|
447 |
+
|
448 |
+
st.sidebar.subheader("About")
|
449 |
+
st.sidebar.write("This app has been developed using a\
|
450 |
+
[COVID-Twitter-BERT](https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2)\
|
451 |
+
model fine-tuned on a monkeypox misinformation\
|
452 |
+
dataset. Users can learn more about the\
|
453 |
+
[model](https://www.bbc.co.uk/sport) on the\
|
454 |
+
HuggingFace model repository and can explore on\
|
455 |
+
Kaggle the [dataset](https://www.kaggle.com/datasets/stephencrone/monkeypox)\
|
456 |
+
on which the model was trained. Further\
|
457 |
+
[documentation](https://www.kaggle.com/datasets/stephencrone/monkeypox),\
|
458 |
+
as well as the source code for the app, can be\
|
459 |
+
found in the project's GitHub repository.")
|
460 |
+
|
461 |
+
st.sidebar.subheader("Contact")
|
462 |
+
st.sidebar.write("If you have any questions, comments or feedback\
|
463 |
+
regarding this app that are not answered by the\
|
464 |
+
supporting documentation for the underpinning\
|
465 |
+
dataset or transformer model, please feel free\
|
466 |
+
to contact the author at sgscrone@liverpool.ac.uk.")
|
467 |
+
|
468 |
+
# Provide a text box for user to enter tweet ID / URL.
|
469 |
+
tweet_to_check = st.text_input("Please provide a tweet URL or ID", key="name")
|
470 |
+
# If the string provided by the user is empty, do nothing.
|
471 |
+
if tweet_to_check != "":
|
472 |
+
# Otherwise, if string is not empty, try fetching tweet using function.
|
473 |
+
try:
|
474 |
+
classified_tweet, tweet_text, followers_count, classified_user_tweets, retweets_total, spreader_score = on_receipt_of_tweet_query(tweet_to_check,client)
|
475 |
+
st.markdown("""<hr style="height:1px;border:none;background-color:#a6a6a6; margin-top:16px; margin-bottom:20px;" /> """, unsafe_allow_html=True)
|
476 |
+
col1, col2 = st.columns(2)
|
477 |
+
# In left column, present tweet text.
|
478 |
+
col1.subheader("Tweet")
|
479 |
+
tweet_text = textacy.preprocessing.normalize.whitespace(tweet_text)
|
480 |
+
col1.markdown('<p style="background-color: #F0F2F6; padding: 8px 8px 8px 8px;">{}{}</p>'.format(tweet_text,type(tweet_text)),unsafe_allow_html=True)
|
481 |
+
# In right column, present tweet classification.
|
482 |
+
col2.subheader("Rating for this tweet")
|
483 |
+
if len(classified_tweet['goodPosts']) != 0:
|
484 |
+
# Format blue for not misinformation.
|
485 |
+
col2.markdown('<p style="color:White; background-color: #1661AD; text-align: center; font-size: 20px;">Not misinformation</p>',unsafe_allow_html=True)
|
486 |
+
col2.markdown('<p style="font-size: 40px; text-align: center;">{}</p>'.format(format(classified_tweet['goodPosts'][0][2],'.0%')), unsafe_allow_html=True)
|
487 |
+
col2.markdown('<p style="text-align: center;">confidence level</p>', unsafe_allow_html=True)
|
488 |
+
else:
|
489 |
+
# Format red for misinformation.
|
490 |
+
col2.markdown('<p style="color:White; background-color: #701B20; text-align: center; font-size: 20px;">Misinformation</p>',unsafe_allow_html=True)
|
491 |
+
col2.markdown('<p style="font-size: 40px; text-align: center;">{}</p>'.format(format(classified_tweet['badPosts'][0][2],'.0%')), unsafe_allow_html=True)
|
492 |
+
col2.markdown('<p style="text-align: center;">confidence level</p>', unsafe_allow_html=True)
|
493 |
+
# Add additional container to display superspreader analysis.
|
494 |
+
superspreader_container = st.container()
|
495 |
+
superspreader_container.subheader("Superspreader rating for this user")
|
496 |
+
# Plot the superspreader score as a bar chart.
|
497 |
+
score_to_plot = pd.DataFrame({"classified_tweet":["score"],"spreader_score":[spreader_score]})
|
498 |
+
bar = alt.Chart(score_to_plot).mark_bar().encode(alt.X('spreader_score:Q',scale=alt.Scale(domain=(0, 100)), axis=None), alt.Y('classified_tweet',axis=None)).properties(height=60)
|
499 |
+
if spreader_score > 10:
|
500 |
+
label = bar.mark_text(align='right',baseline='middle', dx=-10, color='white', fontSize=20).encode(text=alt.Text("spreader_score:Q", format=",.0f"))
|
501 |
+
else:
|
502 |
+
label = bar.mark_text(align='right',baseline='middle', dx=25, color='black', fontSize=20).encode(text=alt.Text("spreader_score:Q", format=",.0f"))
|
503 |
+
x = bar+label
|
504 |
+
x = x.configure_mark(color='#701B20')
|
505 |
+
superspreader_container.altair_chart(x, use_container_width=True)
|
506 |
+
# Display stats on which calculation was based.
|
507 |
+
superspreader_container.write("Based on the user's **{:,} followers** and the following **{} tweet(s)** published over the last two weeks, which together received **{:,} retweet(s)**.".format(followers_count,len(classified_user_tweets['badPosts']),retweets_total))
|
508 |
+
# And print offending tweets from user's recent history.
|
509 |
+
for i in range(len(classified_user_tweets['badPosts'])):
|
510 |
+
recent_tweet = classified_user_tweets['badPosts'][i][0]
|
511 |
+
recent_tweet = recent_tweet.split('text:')[-1]
|
512 |
+
superspreader_container.markdown('<p style="background-color: #F0F2F6; padding: 8px 8px 8px 8px;">{}</p>'.format(recent_tweet),unsafe_allow_html=True)
|
513 |
+
except:
|
514 |
+
st.error("Could not retrieve information for tweet. Please ensure you are supplying a valid tweet ID or URL.")
|
515 |
+
st.markdown("""<hr style="height:1px;border:none;background-color:#a6a6a6; margin-top:16px; margin-bottom:20px;" /> """, unsafe_allow_html=True)
|
516 |
+
|
517 |
+
if __name__ == "__main__":
|
518 |
+
webpage()
|