# Author: Oliver Li
# Commit: 809559e ("milestone3")
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for a checkpoint.

    Args:
        model_name: Checkpoint identifier handed to both ``from_pretrained`` calls
            (the app passes the Hub ids listed in ``model_options``).

    Returns:
        A ``(tokenizer, model)`` tuple; callers must build a pipeline from it.
    """
    # NOTE(review): nothing here is cached, so each "Analyze" click reloads the
    # weights; consider Streamlit's resource caching if the installed version has it.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForSequenceClassification.from_pretrained(model_name),
    )
# Example input shown pre-filled in the text box so the app can be tried immediately.
default_text = "I will kill you if you do not give me my pop tarts."

# --- Streamlit page: title, instructions and the text to analyze ---
st.title("Multi-label Toxicity Detection App")
st.write("Enter a text and select the fine-tuned model to get the toxicity analysis.")
text = st.text_input(label="Enter your text:", value=default_text)
# Raw pipeline labels LABEL_0 ... LABEL_5 mapped, in order, to toxicity category names.
category = {
    f"LABEL_{index}": name
    for index, name in enumerate(
        ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    )
}
# Selectable checkpoints. Each entry carries a human-readable description and,
# for the sentiment models, the label names that model emits.
model_options = {
    "Olivernyu/finetuned_bert_base_uncased": dict(
        description=(
            "This model detects different types of toxicity like threats, "
            "obscenity, insults, and identity-based hate in text."
        ),
    ),
    "distilbert-base-uncased-finetuned-sst-2-english": dict(
        labels=["NEGATIVE", "POSITIVE"],
        description=(
            "This model classifies text into positive or negative sentiment. "
            "It is based on DistilBERT and fine-tuned on the Stanford Sentiment "
            "Treebank (SST-2) dataset."
        ),
    ),
    "textattack/bert-base-uncased-SST-2": dict(
        labels=["LABEL_0", "LABEL_1"],
        description=(
            "This model classifies text into positive(LABEL_1) or negative(LABEL_0) sentiment. "
            "It is based on BERT and fine-tuned on the Stanford Sentiment "
            "Treebank (SST-2) dataset."
        ),
    ),
    "cardiffnlp/twitter-roberta-base-sentiment": dict(
        labels=["LABEL_0", "LABEL_1", "LABEL_2"],
        description=(
            "This model classifies tweets into negative (LABEL_0), neutral(LABEL_1), "
            "or positive(LABEL_2) sentiment. "
            "It is based on RoBERTa and fine-tuned on a large dataset of tweets."
        ),
    ),
}
# The dict's keys (model ids) are offered as choices; the chosen key is returned.
selected_model = st.selectbox(
    "Choose a fine-tuned model:",
    model_options,
)

# Show the description of whichever model is currently selected.
st.write("### Model Information")
st.write("**Description:** " + model_options[selected_model]["description"])
# Load the selected model and analyze the text when the "Analyze" button is clicked.
if st.button("Analyze"):
    if not text:
        st.write("Please enter a text.")
    else:
        with st.spinner("Analyzing toxicity..."):
            # load_model returns a (tokenizer, model) pair, not a callable, so both
            # branches wrap it in a pipeline before running inference.
            tokenizer, model = load_model(selected_model)
            if selected_model == "Olivernyu/finetuned_bert_base_uncased":
                toxicity_detector = pipeline("text-classification", tokenizer=tokenizer, model=model)
                outputs = toxicity_detector(text, top_k=2)
                # For a single string with an explicit top_k the pipeline yields a flat
                # list of {"label", "score"} dicts; unwrap a nested [[...]] result in
                # case the installed transformers version returns that shape instead.
                if outputs and isinstance(outputs[0], list):
                    outputs = outputs[0]
                # Translate raw LABEL_n ids to category names; keep unknown labels verbatim
                # instead of failing with KeyError.
                results = [
                    (category.get(item["label"], item["label"]), item["score"])
                    for item in outputs
                ]
                # One-row table: a snippet of the input plus one column per returned
                # category (works for however many results top_k produced).
                table_data = {"Text (portion)": [text[:50]]}
                for category_name, score in results:
                    table_data[category_name] = [score]
                table_df = pd.DataFrame(table_data)
                st.table(table_df)
            else:
                sentiment_pipeline = pipeline("text-classification", tokenizer=tokenizer, model=model)
                result = sentiment_pipeline(text)
                label, score = result[0]["label"], result[0]["score"]
                st.write(f"Sentiment: {label} (confidence: {score:.2f})")
                # Each sentiment entry in model_options lists its labels from negative to
                # positive (per the model descriptions), so the first label is the negative
                # class and the last the positive one. Hard-coding LABEL_1 as positive was
                # wrong for cardiffnlp, where LABEL_1 is neutral and LABEL_2 is positive.
                model_labels = model_options[selected_model].get("labels", [])
                if model_labels:
                    negative_labels = {model_labels[0]}
                    positive_labels = {model_labels[-1]}
                else:
                    # Fallback for entries without a label list: the original label sets.
                    negative_labels = {"NEGATIVE", "LABEL_0"}
                    positive_labels = {"POSITIVE", "LABEL_1"}
                if label in positive_labels and score > 0.9:
                    st.balloons()
                elif label in negative_labels and score > 0.9:
                    st.error("Hater detected.")
else:
    st.write("Enter a text and click 'Analyze' to perform toxicity analysis.")