import streamlit as st
import plotly.express as px
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face checkpoint for each selectable model, plus the logit index of the
# "neutral" and "toxic" classes for that checkpoint (label order differs
# between checkpoints, so it has to be recorded per model).
MODEL_SPECS = {
    "RoBERTa": ("s-nlp/roberta_toxicity_classifier", 0, 1),
    "DistilBERT": (
        "citizenlab/distilbert-base-multilingual-cased-toxicity",
        1,
        0,
    ),
    "XLM-RoBERTa": ("unitary/multilingual-toxic-xlm-roberta", 1, 0),
}
# Used when the selection is not a key of MODEL_SPECS.
DEFAULT_MODEL = "RoBERTa"

# Streamlit re-runs this whole script on every widget interaction, so without
# caching the tokenizer and model would be re-loaded each time. st.cache_resource
# keeps one instance per checkpoint across reruns; on Streamlit versions that
# predate it, fall back to an identity decorator (no caching, same behavior).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def load_model(path):
    """Load and return ``(tokenizer, model)`` for the Hugging Face checkpoint *path*."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(path)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(path)
    return loaded_tokenizer, loaded_model


option = st.selectbox("Select a toxicity analysis model:", ("RoBERTa", "DistilBERT", "XLM-RoBERTa"))
defaultTxt = "I hate you cancerous insects so much"
txt = st.text_area("Text to analyze", defaultTxt)
st.button("Submit Text")

# Resolve the checkpoint and class indices; unknown selections fall back to RoBERTa.
modelPath, neutralIndex, toxicIndex = MODEL_SPECS.get(option, MODEL_SPECS[DEFAULT_MODEL])
# Every supported checkpoint ships its tokenizer alongside the model weights.
tokenizerPath = modelPath
tokenizer, model = load_model(modelPath)

# Encode the text and run it through the classifier. truncation=True cuts input
# longer than the tokenizer's model_max_length so long text does not overflow
# the model's position embeddings (assumes each checkpoint's tokenizer config
# sets model_max_length -- TODO confirm for all three).
encoding = tokenizer.encode(txt, return_tensors='pt', truncation=True)
with torch.no_grad():  # inference only: skip building the autograd graph
    result = model(encoding)

# Some checkpoints may emit a single logit instead of one per class. Padding
# it with a 0 logit makes softmax([x, 0]) == [sigmoid(x), 1 - sigmoid(x)],
# i.e. index 0 becomes the positive-class probability.
if result.logits.size(dim=1) < 2:
    result.logits = nn.functional.pad(result.logits, (0, 1), "constant", 0)

# Convert logits to class probabilities for the single input sequence.
prediction = nn.functional.softmax(result.logits, dim=-1)
neutralProb = prediction[0][neutralIndex].item()
toxicProb = prediction[0][toxicIndex].item()

# Expected returns from RoBERTa on default text:
# Neutral: 0.0052
# Toxic: 0.9948
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")