# Source: tweetpie Streamlit app (commit d16bcc6 — "updates for webm generation")
import time
import pandas as pd
import streamlit as st
from transformers import pipeline
from constants import tweet_generator_prompt, absa_prompt
# Pipelines are loaded once and cached so Streamlit reruns reuse them
# instead of re-downloading/re-initializing the models on every interaction.
@st.cache_resource
def load_model():
    """Load and cache the three Hugging Face pipelines used by the dashboard.

    Returns:
        tuple: (toxicity classification pipeline,
                stance-aware ABSA text2text pipeline,
                tweet generation text2text pipeline)
    """
    def _timed(label, factory):
        # Build one pipeline via `factory` and log its load time.
        t0 = time.time()
        pipe = factory()
        print(f"Time to load the {label} model: {time.time() - t0:.2f}s")
        return pipe

    classification_pipe = _timed(
        "classification",
        lambda: pipeline(
            "text-classification", model="tweetpie/toxic-content-detector", top_k=None),
    )
    absa_pipe = _timed(
        "absa",
        lambda: pipeline("text2text-generation", model="tweetpie/stance-aware-absa"),
    )
    tweet_generation_pipe = _timed(
        "tweet generation",
        lambda: pipeline(
            "text2text-generation", model="tweetpie/stance-directed-tweet-generator"),
    )
    return classification_pipe, absa_pipe, tweet_generation_pipe
# --- Page layout ---
# NOTE: Streamlit renders widgets in statement order, so the order of the
# calls below is the order they appear on the page.
# Set up the title
st.title("Toxic Content Classifier Dashboard")
# Top-level input for model selection; the value is lower-cased later when
# formatting the tweet-generation prompt.
model_selection = st.selectbox(
    "Select an ideology",
    options=['Left', 'Right'],
    index=0  # Default selection ('Left')
)
# Layout for entities and aspects inputs: two side-by-side columns,
# entities on the left, aspects on the right.
col1, col2 = st.columns(2)
with col1:
    st.header("Entities")
    # Comma-separated free-text inputs; passed verbatim into the prompt template.
    pro_entities = st.text_input("Pro Entities", help="Enter pro entities separated by commas")
    anti_entities = st.text_input("Anti Entities", help="Enter anti entities separated by commas")
    neutral_entities = st.text_input("Neutral Entities", help="Enter neutral entities separated by commas")
with col2:
    st.header("Aspects")
    pro_aspects = st.text_input("Pro Aspects", help="Enter pro aspects separated by commas")
    anti_aspects = st.text_input("Anti Aspects", help="Enter anti aspects separated by commas")
    neutral_aspects = st.text_input("Neutral Aspects", help="Enter neutral aspects separated by commas")
# Generate button: True only on the script rerun triggered by a click,
# which gates the generate/classify flow below.
generate_button = st.button("Generate and classify toxicity")
# Load the (cached) models: toxicity classifier, stance-aware ABSA, tweet generator.
classifier, absa, generator = load_model()
# Process the input text and generate output (runs only on the rerun caused
# by a click on the "Generate and classify toxicity" button).
if generate_button:
    with st.spinner('Generating the tweet...'):
        # Fill the prompt template with the selected ideology and the raw
        # comma-separated entity/aspect strings from the form.
        prompt = tweet_generator_prompt.format(
            ideology=model_selection.lower(),
            pro_entities=pro_entities,
            anti_entities=anti_entities,
            neutral_entities=neutral_entities,
            pro_aspects=pro_aspects,
            anti_aspects=anti_aspects,
            neutral_aspects=neutral_aspects
        )
        # NOTE(review): tweet generation is currently stubbed with a canned
        # response; the sleep simulates model latency. To re-enable the real
        # model, replace the two lines below with:
        #   generated_tweet = generator(prompt, max_new_tokens=80,
        #                               do_sample=True, num_return_sequences=3)
        time.sleep(5)
        generated_tweet = [{"generated_text": "the agricultural sector is the single biggest recipient of migrants workers rights groups argue . nearly 90 % of those who come to the us are denied employment due to discriminatory employment laws and safety standards ."}]

    # Displaying the input and model's output
    st.write(f"Generated Tweet: {generated_tweet[0]['generated_text']}")

    with st.spinner('Generating the Stance-Aware ABSA output...'):
        # Run ABSA on the generated tweet; the model is expected to emit a
        # comma-separated list of "aspect:sentiment" pairs.
        absa_output = absa(absa_prompt.format(generated_tweet=generated_tweet[0]['generated_text']))
        print("ABSA Output: ", absa_output)
        stances = []
        for item in absa_output[0]['generated_text'].split(','):
            # partition() instead of split(':')[1]: a malformed entry without
            # a colon previously raised IndexError; now it is skipped.
            aspect, sep, sentiment = item.strip().partition(':')
            if sep:
                stances.append({'Aspect': aspect.strip(), 'Sentiment': sentiment.strip()})
        stances_df = pd.DataFrame(stances)
        st.write("Stance-Aware ABSA Output:")
        st.table(stances_df)

    with st.spinner('Classifying the toxicity...'):
        # Call the classifier with the generated tweet; with top_k=None the
        # pipeline returns, per input, a list of {'label', 'score'} dicts.
        model_output = classifier(generated_tweet[0]['generated_text'])
        output = model_output[0]
        st.write("Toxicity Classifier Output:")
        # Iterate over the actual predictions rather than a hard-coded
        # range(3): the original indexed output[2] and raised IndexError
        # whenever the model returned fewer than three labels.
        for entry in output:
            if entry['label'] == 'LABEL_0':
                st.write(f"Non-Toxic Content: {entry['score']*100:.1f}%")
            elif entry['label'] == 'LABEL_1':
                st.write(f"Toxic Content: {entry['score']*100:.1f}%")