# Streamlit app that scores a message with a pre-trained toxicity classifier.
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization
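# Load the pre-trained Keras toxicity model.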
model = tf.keras.models.load_model('toxicity_model.h5')
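# Rebuild the vocabulary from the training comments so tokenization matches what the model saw during training.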
dataset = pd.read_csv('train.csv')
comments = dataset['comment_text']
vectorizer = TextVectorization(max_tokens=2500000,
                               output_sequence_length=1800,
                               output_mode='int')
vectorizer.adapt(comments.values)
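# Streamlit UI: page title and a text box for the message to score.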
st.title('Toxicity Classifier')
st.header('Write a message here:')
text = st.text_area('The toxicity of the message will be evaluated.',
                    value="You're fucking ugly.")
input_str = vectorizer(text)
res = model.predict(np.expand_dims(input_str,0))
classification = res[0].tolist()
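# Unpack the six per-class scores: overall toxicity, severe toxicity, obscenity, threat, insult and identity hate.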
toxicity = classification[0]
toxicity_severe = classification[1]
obscene = classification[2]
threat = classification[3]
insult = classification[4]
identity_hate = classification[5]
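# Track which sub-classification (excluding overall toxicity) scores highest.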
highest_class = "Severe toxicity"
highest_class_rating = toxicity_severe
if(obscene > highest_class_rating):
highest_class = "Obscenity"
highest_class_rating = obscene
if(threat > highest_class_rating):
highest_class = "Threat"
highest_class_rating = threat
if(insult > highest_class_rating):
highest_class = "Insult"
highest_class_rating = insult
if(identity_hate > highest_class_rating):
highest_class = "Identity hate"
highest_class_rating = identity_hate
st.write("---")
st.write("Overall toxicity rating: " +str(toxicity))
st.write("---")
st.write("Classifications:")
if toxicity_severe > 0.5:
    st.write("Severely toxic - " + str(toxicity_severe))
if obscene > 0.5:
    st.write("Obscene - " + str(obscene))
if threat > 0.5:
    st.write("Threat - " + str(threat))
if insult > 0.5:
    st.write("Insult - " + str(insult))
if identity_hate > 0.5:
    st.write("Identity hate - " + str(identity_hate))
st.write("---")
st.write("Invalid classifications:")
if toxicity_severe <= 0.5:
    st.write("Severely toxic - " + str(toxicity_severe))
if obscene <= 0.5:
    st.write("Obscene - " + str(obscene))
if threat <= 0.5:
    st.write("Threat - " + str(threat))
if insult <= 0.5:
    st.write("Insult - " + str(insult))
if identity_hate <= 0.5:
    st.write("Identity hate - " + str(identity_hate))