import streamlit as st
from transformers import AutoTokenizer, DistilBertTokenizerFast, pipeline

st.title("Toxic Tweets")

# Hugging Face Hub model ids offered in the model picker.
models = [
    "notbhu/toxic-tweet-classifier",
    "distilbert-base-uncased-finetuned-sst-2-english",
    "cardiffnlp/twitter-roberta-base-sentiment",
    "Seethal/sentiment_analysis_generic_dataset",
]

# Tokenizer used when a model repository does not provide its own tokenizer
# files (this is the tokenizer the app previously used for every model).
FALLBACK_TOKENIZER = "distilbert-base-uncased"

# Human-readable names for each model's raw pipeline labels.
LABELS = {
    "notbhu/toxic-tweet-classifier": {
        "LABEL_0": "toxic",
        "LABEL_1": "severe_toxic",
        "LABEL_2": "obscene",
        "LABEL_3": "threat",
        "LABEL_4": "insult",
        "LABEL_5": "identity_hate",
    },
    "distilbert-base-uncased-finetuned-sst-2-english": {
        "POSITIVE": "POSITIVE",
        "NEGATIVE": "NEGATIVE",
    },
    "cardiffnlp/twitter-roberta-base-sentiment": {
        "LABEL_0": "NEGATIVE",
        "LABEL_1": "NEUTRAL",
        "LABEL_2": "POSITIVE",
    },
    "Seethal/sentiment_analysis_generic_dataset": {
        "LABEL_0": "NEGATIVE",
        "LABEL_1": "POSITIVE",
    },
}

# Sample input pre-filled in the text area. The emoji were previously stored
# mis-encoded (UTF-8 bytes shown through a single-byte code page).
default_tweet = (
    "🐰🌸🐣 Happy Easter 🌸🐰🐣! It's time to crack open some eggs 🥚 and "
    "celebrate with the Easter Bunny 🐰🐇. Hop 🐇 on over to church ⛪️ and get "
    "down on your knees 🧎‍♂️🙏 for some Easter blessings 🐰✝️🌷. Did you know "
    "that Jesus 🙏💒 died and rose again 💀🙌🌅? It's a time for rejoicing 🎉 "
    "and enjoying the company of loved ones 👨‍👩‍👧‍👦. So put on your Sunday "
    "best 👗 and get ready to hunt 🕵️‍♀️ for some Easter treats 🍫🥚🍭. Happy "
    "Easter, bunnies 🐰👯‍♀️! Don't forget to spread the love ❀️ and send this "
    "message to your favorite bunnies 💌🐇. "
)

st.image(
    "https://www.gannett-cdn.com/presto/2022/04/12/USAT/3a93e183-d87d-493a-97a9-cf75fb7b9d18-AP_Pennsylvania_Easter.jpg"
)

tweet = st.text_area("Enter a tweet", value=default_tweet)
model = st.selectbox("Select a model", models)
button = st.button("Predict")


def getLabel(label, model):
    """Return the human-readable name of a raw pipeline *label* for *model*.

    Labels (or models) missing from ``LABELS`` are returned unchanged instead
    of raising ``KeyError``.
    """
    return LABELS.get(model, {}).get(label, label)


def _cache_resource(func):
    """Wrap *func* in ``st.cache_resource`` when this Streamlit version has it.

    Streamlit re-executes the whole script on every interaction, so without a
    Streamlit-level cache the model would be reloaded on each rerun. On
    Streamlit releases without ``st.cache_resource`` the function is returned
    uncached rather than failing at import time.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is None:
        return func
    return cache(show_spinner=False)(func)


@_cache_resource
def load_classifier(model):
    """Build (and, where supported, cache) the classification pipeline for *model*.

    Uses the model's own tokenizer so that token ids match its vocabulary
    (e.g. the RoBERTa-based cardiffnlp model). Falls back to the DistilBERT
    tokenizer when the model repository does not provide a usable one.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model)
    except Exception:
        # Deliberately broad best-effort probe: any failure falls back to the
        # tokenizer this app originally used for every model.
        tokenizer = DistilBertTokenizerFast.from_pretrained(FALLBACK_TOKENIZER)
    return pipeline(model=model, tokenizer=tokenizer)


def predict(tweet, model):
    """Classify *tweet* with *model* and render the outcome in the app.

    Shows a spinner while the model loads and runs. Any failure (including
    model loading) is reported with ``st.error`` instead of crashing the page.
    On success, shows the mapped label and score, with balloons for POSITIVE.
    """
    with st.spinner("Predicting..."):
        try:
            classifier = load_classifier(model)
            result = classifier(tweet)
            label = getLabel(result[0]["label"], model)
            score = result[0]["score"]
        except Exception as e:  # top-level UI boundary: report, don't crash
            st.error("Something went wrong")
            st.error(e)
            return
    if label == "POSITIVE":
        st.balloons()
    st.info(f"Label: {label} \n\n Score: {score}")


# Prediction runs automatically whenever the text area is non-empty; pressing
# the button on an empty text area only produces a warning.
if button and not tweet:
    st.warning("Please enter a tweet")
elif tweet:
    predict(tweet, model)