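"""Streamlit app for the Toxic Tweets Analyzer.

Scores a tweet with the kingsotn/finetuned_cyberbullying DistilBERT model
(toxic, severe_toxic, obscene, threat, insult, identity_hate), or runs a
plain sentiment-analysis pipeline for the other selectable models.
"""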
from random import randint

import pandas as pd
import requests
import streamlit as st
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertTokenizerFast,
    pipeline,
)

import comments
def predict_cyberbullying_probability(sentence, tokenizer, model):
    # Preprocess the input sentence (pad/truncate to the model's 512-token limit)
    inputs = tokenizer(sentence, padding='max_length', return_token_type_ids=False,
                       return_attention_mask=True, truncation=True, max_length=512,
                       return_tensors='pt')
    attention_mask = inputs['attention_mask']
    input_ids = inputs['input_ids']

    with torch.no_grad():
        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask)

    # Multi-label output: a sigmoid over each logit gives independent per-label probabilities
    probs = torch.sigmoid(outputs.logits).flatten()
    return probs.numpy().tolist()
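
# Illustrative usage (not executed by the app): the returned list lines up with
# labels[1:] defined below, i.e. one probability in [0, 1] for each of the six toxicity labels.
# probs = predict_cyberbullying_probability("some tweet", tokenizer, model)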
# @st.cache
def perform_cyberbullying_analysis(tweet):
    with st.spinner(text="loading model, wait until spinner ends..."):
        model = AutoModelForSequenceClassification.from_pretrained('kingsotn/finetuned_cyberbullying')
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

        # Build a one-row table: the comment text plus one probability per toxicity label
        df = pd.DataFrame({'comment': [tweet]})
        list_probs = predict_cyberbullying_probability(tweet, tokenizer, model)
        for i, label in enumerate(labels[1:]):
            df[label] = list_probs[i]

    return df
def perform_default_analysis(model_name):
    # Build a standard sentiment-analysis pipeline for the selected model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, framework="pt")

    tweet = st.text_area(label="Enter Text:", value="I'm nice at ping pong")
    submitted = st.form_submit_button("Analyze")

    if submitted:
        # loading bar
        with st.spinner(text="loading..."):
            out = clf(tweet)
            st.json(out)

        if out[0]["label"] == "POSITIVE" or out[0]["label"] == "POS":
            st.balloons()
            # prompt = f"{basic_prompt}\n\nThe user wrote a tweet that says: {tweet}, compliment them on how nice of a person they are! Remember, try to be as cringe and awkward as possible!"
            # response = generator(prompt, max_length=1000)[0]
            st.success("nice tweet!")
        else:
            # prompt = f"{basic_prompt}\n\nThe user wrote a tweet that says: {tweet}, tell them how terrible of a person they are! Remember, try to be as cringe and awkward as possible!"
            # response = generator(prompt, max_length=1000)[0]
            st.error("bad tweet!")
# main -->
st.title("Toxic Tweets Analyzer")
st.write("💡 Toxic Tweets Analyzer scores tweets for toxicity, threats, and insults with kingsotn/finetuned_cyberbullying, a fine-tuned DistilBERT model.")

image = "kanye_loves_tweet.jpg"
st.image(image, use_column_width=True)

# 'comment' holds the raw text; the remaining six columns are the toxicity labels the model scores
labels = ['comment', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
with st.form("my_form"):
    # Select model
    model_name = st.selectbox(
        "Enter a text and select a pre-trained model to get the sentiment analysis",
        ["kingsotn/finetuned_cyberbullying", "distilbert-base-uncased-finetuned-sst-2-english",
         "finiteautomata/bertweet-base-sentiment-analysis", "distilbert-base-uncased"],
    )

    if model_name == "kingsotn/finetuned_cyberbullying":
        default = "I'm not even going to lie to you. I love me so much right now."
        tweet = st.text_area(label="Enter Text:", value=default)
        submitted = st.form_submit_button("Analyze textbox")
        random = st.form_submit_button("Get a random 😈😈😈 tweet (warning!!)")
        kanye = st.form_submit_button("Get a ye quote 🐻🎀🎧🎢")

        if random:
            # Pick a random comment from the bundled comments list
            tweet = comments.comments[randint(0, len(comments.comments) - 1)]
            st.write(tweet)
            submitted = True

        if kanye:
            # Pull a quote from the public kanye.rest API
            response = requests.get('https://api.kanye.rest/')
            if response.status_code == 200:
                data = response.json()
                tweet = data['quote']
            else:
                st.error("Error getting Kanye quote | status code: " + str(response.status_code))
            st.write(tweet)
            submitted = True

        if submitted:
            df = perform_cyberbullying_analysis(tweet)
            st.table(df)
    else:
        perform_default_analysis(model_name)