milestone-2 / app.py
nppmatt's picture
cleanup
1345311
import streamlit as st
import plotly.express as px
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Display name -> (Hugging Face Hub repo id, index of the NEUTRAL class,
# index of the TOXIC class). Every model here loads its tokenizer and its
# weights from the same repo. The class order differs between checkpoints,
# so it is recorded per entry instead of being assumed.
MODEL_CONFIGS = {
    "RoBERTa": ("s-nlp/roberta_toxicity_classifier", 0, 1),
    "DistilBERT": ("citizenlab/distilbert-base-multilingual-cased-toxicity", 1, 0),
    "XLM-RoBERTa": ("unitary/multilingual-toxic-xlm-roberta", 1, 0),
}
# Used when the selection is not a known key.
DEFAULT_MODEL = "RoBERTa"

# The menu is built from MODEL_CONFIGS, so the options shown and the lookup
# below always agree. Insertion order keeps RoBERTa first (the default).
option = st.selectbox("Select a toxicity analysis model:", tuple(MODEL_CONFIGS))
defaultTxt = "I hate you cancerous insects so much"
txt = st.text_area("Text to analyze", defaultTxt)
# The button's return value is not read. Streamlit reruns the script on any
# widget interaction, so the button only acts as an explicit "rerun" control.
st.button("Submit Text")

# Load tokenizer and model weights, falling back to RoBERTa for any
# unrecognized selection.
# NOTE(review): this reloads the model on every rerun; consider caching the
# load with st.cache_resource if the installed Streamlit version provides it.
modelPath, neutralIndex, toxicIndex = MODEL_CONFIGS.get(
    option, MODEL_CONFIGS[DEFAULT_MODEL]
)
tokenizer = AutoTokenizer.from_pretrained(modelPath)
model = AutoModelForSequenceClassification.from_pretrained(modelPath)
# Run the text through the selected model to get classification logits.
# The encoders offered here are 512-token models. Without truncation, a long
# input overruns the position embeddings and crashes the model call. With
# truncation, it is classified on its first 512 tokens (special tokens
# included).
encoding = tokenizer.encode(
    txt, return_tensors='pt', truncation=True, max_length=512
)
# Inference only: don't build an autograd graph for the forward pass.
with torch.no_grad():
    result = model(encoding)
logits = result.logits
# Some checkpoints emit a single logit per example. Appending a constant-0
# column makes softmax([x, 0])[0] == sigmoid(x) and [1] == 1 - sigmoid(x),
# so the two-class indexing below works for either head shape.
if logits.size(dim=1) < 2:
    logits = nn.functional.pad(logits, (0, 1), "constant", 0)
# Transform logits to probabilities; row 0 is the single input text.
prediction = nn.functional.softmax(logits, dim=-1)
neutralProb = prediction[0, neutralIndex].item()
toxicProb = prediction[0, toxicIndex].item()
# Render the two class probabilities, one line each, under a heading.
# Reference output recorded by the author for RoBERTa on the default text:
#   NEUTRAL 0.0052, TOXIC 0.9948
display_lines = (
    "Classification Probabilities",
    f"{neutralProb:.4f} - NEUTRAL",
    f"{toxicProb:.4f} - TOXIC",
)
for line in display_lines:
    st.write(line)