import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization

# Load the trained toxicity model and rebuild the vectorizer from the training
# data so that tokens map to the same integer ids the model was trained on.
# Note: Streamlit reruns this whole script on every interaction, so this setup
# runs again each time the text changes.
model = tf.keras.models.load_model('toxicity_model.h5')
dataset = pd.read_csv('train.csv')
comments = dataset['comment_text']
vectorizer = TextVectorization(max_tokens=2500000,
                               output_sequence_length=1800,
                               output_mode='int')
vectorizer.adapt(comments.values)

# UI: a single text area whose contents get classified.
st.title('Toxicity Classifier')
st.header('Write a message here:')
text = st.text_area('The toxicity of the message will be evaluated.',
                    value="You're fucking ugly.")

# Vectorize the message and run the model. The model outputs six sigmoid
# scores in the order: toxic, severe_toxic, obscene, threat, insult,
# identity_hate (matching the training labels).
input_ids = vectorizer(text)
res = model.predict(np.expand_dims(input_ids, 0))
toxicity, toxicity_severe, obscene, threat, insult, identity_hate = res[0].tolist()

# Find the strongest of the five sub-categories (overall toxicity is reported separately).
scores = {
    "Severely toxic": toxicity_severe,
    "Obscene": obscene,
    "Threat": threat,
    "Insult": insult,
    "Identity hate": identity_hate,
}
highest_class, highest_class_rating = max(scores.items(), key=lambda item: item[1])

st.write("---")
st.write(f"Overall toxicity rating: {toxicity}")
st.write(f"Strongest category: {highest_class} - {highest_class_rating}")

# Categories whose score clears the 0.5 decision threshold.
st.write("---")
st.write("Classifications:")
for label, score in scores.items():
    if score > 0.5:
        st.write(f"{label} - {score}")

# Categories below the threshold, shown for reference.
st.write("---")
st.write("Classifications below the 0.5 threshold:")
for label, score in scores.items():
    if score <= 0.5:
        st.write(f"{label} - {score}")
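
# --- Optional: cache the heavy setup so it doesn't rerun on every interaction ---
# A minimal sketch, assuming Streamlit >= 1.18 (st.cache_resource) and the same
# 'toxicity_model.h5' / 'train.csv' files as above; the helper name
# load_model_and_vectorizer is illustrative, not part of the original app.
# Loading the model and re-adapting the vectorizer are by far the slowest steps,
# and Streamlit reruns the whole script on every widget change, so wrapping them
# in a cached loader keeps the app responsive.
@st.cache_resource
def load_model_and_vectorizer():
    cached_model = tf.keras.models.load_model('toxicity_model.h5')
    train_comments = pd.read_csv('train.csv')['comment_text']
    cached_vectorizer = TextVectorization(max_tokens=2500000,
                                          output_sequence_length=1800,
                                          output_mode='int')
    cached_vectorizer.adapt(train_comments.values)
    return cached_model, cached_vectorizer

# Usage: replace the module-level loading above with
#     model, vectorizer = load_model_and_vectorizer()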