"""Streamlit app that scores a piece of text for toxicity or sentiment.

The user picks a Hugging Face model. The fine-tuned cyberbullying model yields
one sigmoid probability per toxicity label, shown as a table. Every other model
is run through a ``sentiment-analysis`` pipeline, and its raw output is shown as
JSON.
"""
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, pipeline, AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizerFast
import pandas as pd
import comments
from random import randint
import requests

# Columns of the result table: the input text first, then one column per
# toxicity label. The label order is assumed to match the order of the
# cyberbullying model's output logits (TODO: confirm against the model card).
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

CYBERBULLYING_MODEL = "kingsotn/finetuned_cyberbullying"

# Seconds to wait for the Kanye quote API before giving up, so a stalled
# connection cannot hang the app.
KANYE_API_TIMEOUT = 10

# Streamlit re-executes this script on every interaction. st.cache_resource
# (newer Streamlit releases) keeps loaded models alive across those reruns.
# On older releases, fall back to a no-op decorator so the app still runs,
# just without caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_cyberbullying_model():
    """Load the fine-tuned cyberbullying model and its tokenizer.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(CYBERBULLYING_MODEL)
    tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return tokenizer, model


@_cache_resource
def _load_sentiment_pipeline(model_name):
    """Build a PyTorch ``sentiment-analysis`` pipeline for *model_name*."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")


def predict_cyberbullying_probability(sentence, tokenizer, model):
    """Score a single sentence with a multi-label classifier.

    Args:
        sentence: Text to score.
        tokenizer: Hugging Face tokenizer compatible with *model*.
        model: Sequence-classification model whose logits are passed through
            an element-wise sigmoid, so each label is scored independently.

    Returns:
        A flat list of floats in [0, 1], one per logit the model produces.
    """
    # Single sentence -> batch of one, padded/truncated to 512 tokens.
    inputs = tokenizer(sentence, padding='max_length', return_token_type_ids=False,
                       return_attention_mask=True, truncation=True, max_length=512,
                       return_tensors='pt')
    with torch.no_grad():
        outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        # Sigmoid rather than softmax: labels are not mutually exclusive.
        # flatten() collapses the batch-of-one logits into a 1-D tensor.
        probs = torch.sigmoid(outputs.logits.flatten())
    return probs.numpy().tolist()


def perform_cyberbullying_analysis(tweet):
    """Run the cyberbullying model on *tweet* and return a one-row DataFrame.

    The row has a ``comment`` column holding the text, plus one column per
    toxicity label in ``labels[1:]`` holding that label's probability.
    """
    with st.spinner(text="loading model, wait until spinner ends..."):
        tokenizer, model = _load_cyberbullying_model()
        list_probs = predict_cyberbullying_probability(tweet, tokenizer, model)
    df = pd.DataFrame({'comment': [tweet]})
    # zip instead of indexing, so a model that emits fewer logits than there
    # are labels cannot raise IndexError (missing labels are simply omitted).
    for label, prob in zip(labels[1:], list_probs):
        df[label] = prob
    return df


def perform_default_analysis(model_name):
    """Render a text box and submit button, then run sentiment analysis on submit.

    Must be called inside an ``st.form`` because it uses
    ``st.form_submit_button``.
    """
    tweet = st.text_area(label="Enter Text:",value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")
    if not submitted:
        return
    # Load only after submit, so the form renders without waiting for weights.
    with st.spinner(text="loading..."):
        clf = _load_sentiment_pipeline(model_name)
        out = clf(tweet)
    st.json(out)
    label = out[0]["label"]
    # "POSITIVE"/"NEGATIVE" and "POS"/"NEU"/"NEG" are the label spellings
    # handled here. The first pair is assumed to come from the SST-2 model and
    # the second from bertweet (TODO: confirm against each model's config).
    if label in ("POSITIVE", "POS"):
        st.balloons()
        st.success("nice tweet!")
    elif label == "NEU":
        st.info("neutral tweet.")
    else:
        st.error("bad tweet!")


def _fetch_kanye_quote():
    """Fetch a quote from api.kanye.rest.

    Returns:
        The quote string, or ``None`` after showing an ``st.error`` message if
        the request fails, times out, returns a non-200 status, or returns a
        body that is not the expected JSON.
    """
    try:
        response = requests.get('https://api.kanye.rest/', timeout=KANYE_API_TIMEOUT)
    except requests.RequestException as err:
        st.error(f"Error getting Kanye quote | {err}")
        return None
    if response.status_code != 200:
        st.error("Error getting Kanye quote | status code: " + str(response.status_code))
        return None
    try:
        data = response.json()
    except ValueError as err:
        st.error(f"Error getting Kanye quote | invalid response: {err}")
        return None
    quote = data.get('quote') if isinstance(data, dict) else None
    if not quote:
        st.error("Error getting Kanye quote | response had no quote")
        return None
    return quote


# main -->
st.title("Toxic Tweets Analyzer")
st.write("💡 Toxic Tweets Analyzer is an app that helps you determine the likelihood of a tweet or any text being toxic, abusive or cyberbullying. The app offers different pre-trained models to choose from, each with their own strengths and limitations. kingsotn/finetuned_cyberbullying is a finetuned distilbert. It uses artificial intelligence to analyze the text you input and then calculates a probability score for each label: toxic, severe_toxic, obscene, threat, insult, and identity_hate. The scores range from 0 to 1, with 1 being the highest probability of that label being present in the tweet. The output is a table that shows the probability scores for each label, giving you an idea of the toxicity of the tweet. This can be helpful in identifying and preventing cyberbullying and other forms of online abuse.")
image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

with st.form("my_form"):
    # select model
    model_name = st.selectbox("Enter a text and select a pre-trained model to get the sentiment analysis", [CYBERBULLYING_MODEL, "distilbert-base-uncased-finetuned-sst-2-english", "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"])
    if model_name == CYBERBULLYING_MODEL:
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:",value=default)
        submitted = st.form_submit_button("Analyze textbox")
        random_clicked = st.form_submit_button("Get a random 😈😈😈 tweet (warning!!)")
        kanye_clicked = st.form_submit_button("Get a ye quote 🐻🎤🎧🎶")
        if random_clicked:
            # Sample over the whole list rather than a hard-coded index range.
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True
        if kanye_clicked:
            quote = _fetch_kanye_quote()
            # Only analyze when a quote was actually retrieved. On failure the
            # error is already shown, and the default text is not scored.
            if quote is not None:
                tweet = quote
                st.write(tweet)
                submitted = True
        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)