Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
import pandas as pd | |
import random | |
# Names of the toxicity labels the classifiers predict. Order matters: each
# model's output probabilities are mapped onto these names by position when
# they are turned into DataFrame columns.
classifiers = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
def reset_scores():
    """Rebind the module-level ``scores_df`` to an empty results table.

    The table's columns are ``'Comment'`` followed by every label name in
    ``classifiers``, in that order.
    """
    global scores_df
    header = ['Comment', *classifiers]
    scores_df = pd.DataFrame(columns=header)
# Fine-tuned checkpoint directory for each selectable base model; any other
# base-model name falls back to the RoBERTa checkpoint.
_MODEL_DIRS = {
    "bert-base-cased": "./bert/_bert_model",
    "distilbert-base-cased": "./distilbert/_distilbert_model",
}
_DEFAULT_MODEL_DIR = "./roberta/_roberta_model"


@st.cache_resource
def _load_model_and_tokenizer(model_base):
    """Load the fine-tuned classifier for *model_base* and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, and
    get_score is called up to four times per run. Without this cache the
    weights were re-read from disk on every call. ``st.cache_resource`` keeps
    one shared (model, tokenizer) pair per ``model_base`` across reruns and
    sessions.
    """
    model_dir = _MODEL_DIRS.get(model_base, _DEFAULT_MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # Weights come from the local fine-tuned directory; the tokenizer is
    # loaded by the base model's name.
    tokenizer = AutoTokenizer.from_pretrained(model_base)
    return model, tokenizer


def get_score(model_base, text):
    """Return per-label sigmoid probabilities for *text*.

    Args:
        model_base: Base model name. "bert-base-cased" and
            "distilbert-base-cased" select their fine-tuned checkpoints;
            any other value selects the RoBERTa checkpoint. The tokenizer
            is always loaded from this name.
        text: The comment to classify; truncated to 512 tokens.

    Returns:
        A torch tensor of independent (sigmoid, not softmax) probabilities
        computed from the model's logits for a batch holding only *text*;
        callers read row ``0``.
    """
    model, tokenizer = _load_model_and_tokenizer(model_base)
    inputs = tokenizer(text, max_length=512, truncation=True,
                       padding=True, return_tensors='pt')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.sigmoid(outputs.logits)
st.title("Toxic Comment Classifier")

model_base = st.selectbox(
    "Select a pretrained model",
    ["roberta-base", "bert-base-cased", "distilbert-base-cased"],
)
text_input = st.text_input("Enter text for toxicity classification", "")
submit_btn = st.button("Submit")

if submit_btn and text_input:
    result = get_score(model_base, text_input)
    df = pd.DataFrame([result[0].tolist()], columns=classifiers)
    df = df.round(2)  # Round the values to 2 decimal places
    # Show each probability as a whole-number percentage. Column-wise
    # Series.map replaces DataFrame.applymap, which pandas 2.1 deprecated.
    df = df.apply(lambda col: col.map('{:.0%}'.format))
    st.table(df)


@st.cache_data
def _load_test_comments():
    """Read the Jigsaw test split once instead of on every script rerun."""
    return pd.read_csv(
        "./jigsaw-toxic-comment-classification-challenge/test.csv")


test_df = _load_test_comments()
# A fresh sample is drawn on every rerun (including the one triggered by
# the Refresh button below).
sample_df = test_df.sample(n=3)

# Build the score table in one go from complete rows instead of enlarging
# an empty frame row by row.
rows = [
    [comment] + get_score(model_base, comment)[0].tolist()
    for comment in sample_df['comment_text']
]
scores_df = pd.DataFrame(rows, columns=['Comment'] + classifiers)
scores_df = scores_df.round(2)

st.subheader("Toxicity Scores for Random Comments")
# Clicking the button reruns the script, which has already sampled and
# scored new comments above. The table must NOT be reset here: doing so
# displayed an empty table right after the success message.
if st.button("Refresh"):
    st.success("New tweets have been loaded!")
st.table(scores_df)