# Importing required modules
import streamlit as st
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras.layers import TextVectorization
st.title('Toxic Tweet Classifier')  # Header of the application

# Allow the user to choose a model from a dropdown menu
modelChoice = st.selectbox('Select a fine-tuned toxicity model to evaluate the tweets below.',
                           ("Toxicity Model (Trained for 1 epoch)",
                            "Toxicity Model (Trained for 3 epochs)"))

# Load whichever model the user selected (the 1-epoch model is the default)
if modelChoice == "Toxicity Model (Trained for 3 epochs)":
    model = tf.keras.models.load_model('toxicity_model_3_epochs.h5')
else:
    model = tf.keras.models.load_model('toxicity_model_1_epoch.h5')
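# Note: a Streamlit script reruns on every widget interaction, so the selected
# model is reloaded from disk each time. A minimal sketch of how the load could
# be cached with st.cache_resource (assuming a recent Streamlit version;
# load_toxicity_model is a hypothetical helper name):
#
#   @st.cache_resource
#   def load_toxicity_model(path):
#       return tf.keras.models.load_model(path)
#
#   model = load_toxicity_model('toxicity_model_1_epoch.h5')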
dataset = pd.read_csv('train.csv')  # Read the .csv of the dataset that the models were trained on
tweets = pd.read_csv('tweets.csv')  # Read the .csv of the dataset (previously test.csv) to use as tweets
comments = dataset['comment_text']  # The training comments column is now referred to as "comments"
tweets = tweets['comment_text']     # The tweets column is now referred to as "tweets"
# Vectorizer characteristics
vectorizer = TextVectorization(max_tokens=2500000,          # Vocabulary size capped at 2,500,000 tokens
                               output_sequence_length=1800, # Pad or truncate every output sequence to 1800 tokens
                               output_mode='int')           # Output integer indices for the split string tokens
vectorizer.adapt(comments.values)  # Build the vectorizer's vocabulary from the training comments
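# For illustration, after adapt() the vectorizer maps raw text to a fixed-length
# integer sequence (the token IDs below are made up; only the shape is meaningful):
#
#   vectorizer('you are great')  ->  [312, 57, 2841, 0, 0, ..., 0]   # length 1800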
highest_classes = []        # List of the highest-rated toxicity class for each tweet
highest_class_ratings = []  # List of the corresponding highest class rating values
table_tweets = []           # List of tweet contents, used to print a table to the user
x = 0  # Counter for the number of tweets classified so far
for tweet in tweets:  # For every tweet in the dataset of tweets
    if x >= 33:  # Stop once 33 tweets have been classified, to prevent oversized output
        break
    if len(tweet) < 450:  # Filter out tweets that are oversized
        input_str = vectorizer(tweet)  # Vectorize the tweet into an integer sequence
        guess = model.predict(np.expand_dims(input_str, 0))  # Predict classification values for the tweet
        classification = guess[0].tolist()  # Convert the predicted scores to a plain Python list
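        # model.predict returns an array of shape (1, 6): one row for the single
        # input, with one score per toxicity class. The scores are assumed to be
        # independent sigmoid outputs in [0, 1], as is typical for multi-label
        # toxicity classifiers.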
        # Assign the classification values to their respective names
        toxicity = classification[0]
        toxicity_severe = classification[1]
        obscene = classification[2]
        threat = classification[3]
        insult = classification[4]
        identity_hate = classification[5]
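        # The index order above presumably follows the label columns of the
        # Jigsaw Toxic Comment train.csv: toxic, severe_toxic, obscene, threat,
        # insult, identity_hate. Note that the general toxicity score is read
        # but is not a candidate in the comparison below.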
highest_class = "Severe toxicity" # Set default highest class as "Severe toxicity"
highest_class_rating = toxicity_severe # Set default highest rating as severe toxicity's rating
# If obscenity has a higher rating, set the highest class and highest rating to it
if(obscene > highest_class_rating):
highest_class = "Obscenity"
highest_class_rating = obscene
# If threat has a higher rating, set the highest class and highest rating to it
if(threat > highest_class_rating):
highest_class = "Threat"
highest_class_rating = threat
# If insult has a higher rating, set the highest class and highest rating to it
if(insult > highest_class_rating):
highest_class = "Insult"
highest_class_rating = insult
# If identity hate has a higher rating, set the highest class and highest rating to it
if(identity_hate > highest_class_rating):
highest_class = "Identity hate"
highest_class_rating = identity_hate
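        # Tie-breaking note: max() returns the first key with the maximal value,
        # so on an exact tie the class listed earliest in class_ratings
        # ("Severe toxicity") wins.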
        highest_classes.append(highest_class)               # Record the highest-rated class of the current tweet
        highest_class_ratings.append(highest_class_rating)  # Record the corresponding rating value
        table_tweets.append(tweet)                          # Record the contents of the current tweet
        x += 1  # Count this tweet toward the 33-tweet limit
# Organize the tweets, highest classes and highest class ratings into a dictionary
data = {'Tweet': table_tweets,
        'Highest Class': highest_classes,
        'Probability': highest_class_ratings}
df = pd.DataFrame(data)  # Create a pandas DataFrame from the dictionary above
st.dataframe(df)         # Display the table of results to the user
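# Optional: the probability column could be rendered as a percentage with a
# pandas Styler, which st.dataframe accepts (a sketch, assuming no other
# styling is needed):
#
#   st.dataframe(df.style.format({'Probability': '{:.2%}'}))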