# Hugging Face Spaces status at capture time: Runtime error
"""Streamlit app: classify a piece of text as neutral vs. toxic.

The user picks one of several Hugging Face sequence-classification
checkpoints and enters text. The app then shows the softmax probability of
the "neutral" and "toxic" classes.
"""
import streamlit as st
import plotly.express as px
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Display name -> (Hugging Face checkpoint, neutral logit index, toxic logit index).
# The tokenizer and the model are loaded from the same checkpoint. The class
# indices differ per checkpoint (e.g. the DistilBERT one lists "toxic" first).
MODELS = {
    "RoBERTa": ("s-nlp/roberta_toxicity_classifier", 0, 1),
    "DistilBERT": ("citizenlab/distilbert-base-multilingual-cased-toxicity", 1, 0),
    "XLM-RoBERTa": ("unitary/multilingual-toxic-xlm-roberta", 0, 1),
}
DEFAULT_MODEL = "RoBERTa"

# Streamlit re-executes this whole script on every widget interaction. Caching
# the loader means each checkpoint is built once, not on every rerun.
# st.cache_resource only exists in newer Streamlit releases, so fall back to a
# no-op decorator (the original uncached behaviour) when it is missing.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def load_classifier(path):
    """Load and return ``(tokenizer, model)`` for the checkpoint at *path*."""
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForSequenceClassification.from_pretrained(path)
    model.eval()  # inference only; make sure dropout is disabled
    return tokenizer, model


option = st.selectbox("Select a toxicity analysis model:", tuple(MODELS))
defaultTxt = "I hate you cancerous insects so much"
txt = st.text_area("Text to analyze", defaultTxt)

# Resolve the selection with a table lookup; this works on any Python version,
# unlike a 3.10+ match statement. An unexpected selection falls back to
# RoBERTa, as the original else-branch did.
modelPath, neutralIndex, toxicIndex = MODELS.get(option, MODELS[DEFAULT_MODEL])
tokenizer, model = load_classifier(modelPath)

# Run the encoding through the model to get classification logits.
# truncation=True keeps over-long input within the model's maximum sequence
# length instead of failing in the forward pass.
encoding = tokenizer.encode(txt, truncation=True, return_tensors="pt")
with torch.no_grad():  # no gradients needed for inference
    logits = model(encoding).logits

# Transform logits into class probabilities (batch of one -> row 0).
prediction = nn.functional.softmax(logits, dim=-1)
neutralProb = prediction[0, neutralIndex].item()
toxicProb = prediction[0, toxicIndex].item()

# Expected returns from RoBERTa on default text:
# Neutral: 0.0052
# Toxic: 0.9948
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")