import random
from typing import AnyStr
import streamlit as st
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import os
from transformers_interpret import SequenceClassificationExplainer
# Map model names to URLs
# Each supported Hugging Face model id is keyed to the URL of its model card,
# which is simply the hub root followed by the id.
model_names_to_URLs = {
    model_id: f'https://huggingface.co/{model_id}'
    for model_id in (
        'ml6team/distilbert-base-dutch-cased-toxic-comments',
        'ml6team/robbert-dutch-base-toxic-comments',
    )
}
# Markdown used for the 'About' entry of the app menu.
# The heading emoji had been stored as mis-decoded UTF-8 bytes; it is restored
# to the real character. The string has no placeholders, so it is a plain
# (non-f) literal.
about_page_markdown = """# 🤬 Dutch Toxic Comment Detection Space
Made by [ML6](https://ml6.eu/).
Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""
# Emoji pools from which a reaction is picked at random for each result.
# The literals had been corrupted (UTF-8 bytes apparently decoded with a Thai
# single-byte charset), so they are restored to real emoji here.
# NOTE(review): three entries of `regular_emojis` had lost the bytes that
# distinguish them; neutral/friendly faces were chosen for those slots.

# Shown for comments judged non-toxic.
regular_emojis = [
    '😐', '🙂', '👶', '😇',
]
# Shown when neither the toxic nor the non-toxic branch applies.
undecided_emojis = [
    '🤨', '🧐', '🥸', '🥴', '🤷',
]
# Shown for comments judged toxic.
potty_mouth_emojis = [
    '🤐', '👿', '😡', '🤬', '☠️', '☣️', '☢️',
]
# Page setup
st.set_page_config(
    page_title="Toxic Comment Detection Space",
    # The icon had been stored as mis-decoded UTF-8 bytes, which is not a
    # valid emoji; restored to the real character.
    page_icon="🤬",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        # None is passed for the two entries the app does not provide.
        'Get help': None,
        'Report a bug': None,
        'About': about_page_markdown,
    },
)
# Model setup
@st.cache(
    allow_output_mutation=True,
    suppress_st_warning=True,
    show_spinner=False,
)
def load_pipeline(model_name):
    """Build the classifier and its token-attribution explainer.

    A spinner is displayed while the objects are being created. The function
    is wrapped in ``st.cache`` with Streamlit's own spinner disabled.
    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    favour of ``st.cache_resource`` -- confirm against the pinned version.

    Args:
        model_name: Model identifier, used for both the model and the
            tokenizer of the text-classification pipeline.

    Returns:
        A ``(pipeline, explainer)`` tuple; the explainer is constructed from
        the pipeline's own model and tokenizer.
    """
    with st.spinner('Loading model (this might take a while)...'):
        classifier = pipeline(
            'text-classification',
            model=model_name,
            tokenizer=model_name,
        )
        explainer = SequenceClassificationExplainer(
            classifier.model,
            classifier.tokenizer,
        )
    return classifier, explainer
# Auxiliary functions
def format_explainer_html(html_string):
    """Extract tokens with attribution-based background color.

    The ``<mark>`` elements of the last ``<td>`` in ``html_string`` are
    regrouped into one ``<span>`` per word: a token whose text starts with
    ``##`` is treated as a continuation of the current word and has that
    prefix removed. The first and last ``<mark>`` are skipped (model specific
    tokens).

    Args:
        html_string: HTML produced by the explainer's visualisation.

    Returns:
        A ``<p>`` tag holding the word spans separated by single spaces. The
        paragraph is empty when there are no tokens between the first and
        last ``<mark>``.
    """
    inside_token_prefix = '##'
    soup = BeautifulSoup(html_string, 'html.parser')
    p = soup.new_tag(
        'p', attrs={'style': 'color: black; background-color: white;'})

    # Select token elements and remove model specific tokens
    current_word = None
    for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
        text = token.font.text.strip()
        is_continuation = text.startswith(inside_token_prefix)
        if is_continuation:
            text = text[len(inside_token_prefix):]

        if current_word is None:
            # Very first token: open a word even if it looks like a
            # continuation, so there is always a span to append to.
            current_word = soup.new_tag('span')
        elif not is_continuation:
            # A new word starts: flush the finished one plus a separator.
            p.append(current_word)
            p.append(' ')
            current_word = soup.new_tag('span')

        token.string = text
        token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
        current_word.append(token)

    # Add last word (absent when the input had no content tokens; appending
    # None to a tag would raise).
    if current_word is not None:
        p.append(current_word)

    # Add left padding to the first and right padding to the last sub-token
    # of each word.
    for span in p.find_all('span'):
        marks = span.find_all('mark')
        marks[0].attrs['style'] = (
            f"{marks[0].attrs['style']}; padding-left: 0.2em;")
        marks[-1].attrs['style'] = (
            f"{marks[-1].attrs['style']}; padding-right: 0.2em;")
    return p
def list_all_filenames(directory: str = './sample-articles/') -> list:
    """Return the sorted names of the ``.txt`` files in ``directory``.

    Only the trailing ``.txt`` is removed from each name, so a file such as
    ``a.txt.backup.txt`` is listed as ``a.txt.backup``. The result is sorted
    because ``os.listdir`` returns entries in arbitrary order.

    Args:
        directory: Folder to scan; defaults to the sample articles folder.

    Returns:
        A list of file names without the ``.txt`` extension.

    Raises:
        OSError: If ``directory`` cannot be listed (e.g. it does not exist).
    """
    suffix = '.txt'
    return sorted(
        entry[:-len(suffix)]
        for entry in os.listdir(directory)
        if entry.endswith(suffix)
    )
def fetch_file_contents(filename: str,
                        directory: str = './sample-articles/') -> str:
    """Read and return the text of ``<directory>/<filename>.txt``.

    The default directory is the one ``list_all_filenames`` scans, so every
    name offered in the article selector can actually be opened. The name is
    tried exactly as given first; the lower-cased form is kept as a fallback
    for callers that relied on the previous always-lower-case lookup.

    Args:
        filename: File name without the ``.txt`` extension.
        directory: Folder containing the file.

    Returns:
        The file contents decoded as UTF-8.

    Raises:
        FileNotFoundError: If neither spelling of the file exists.
    """
    for candidate in (filename, filename.lower()):
        path = os.path.join(directory, f'{candidate}.txt')
        if os.path.isfile(path):
            break
    # When nothing matched, `path` is the lower-cased candidate and open()
    # raises FileNotFoundError for it.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def classify_comment(comment, selected_model):
    """Run the selected model on ``comment`` and store the enriched result.

    The first prediction of the pipeline is extended with the model name, the
    word attributions (computed for the "non-toxic" class), the raw and the
    reformatted attribution HTML, the original text and a randomly chosen
    emoji. The finished dict is appended to ``st.session_state.results``;
    nothing is returned.
    """
    classifier, explainer = load_pipeline(selected_model)
    prediction = classifier(comment)[0]
    prediction['model_name'] = selected_model

    # Explanation: attribute first, then render and post-process the HTML.
    prediction['word_attribution'] = explainer(comment, class_name="non-toxic")
    attribution_html = explainer.visualize()._repr_html_()
    prediction['visualitsation_html'] = attribution_html
    prediction['tokens_with_background'] = format_explainer_html(
        attribution_html)

    # Pick the emoji pool matching the predicted label; anything that is not
    # a sufficiently scored toxic / non-toxic label counts as undecided.
    predicted_label = prediction['label']
    scored_enough = prediction['score'] > 0.1
    if scored_enough and predicted_label == 'toxic':
        emoji_pool = potty_mouth_emojis
    elif scored_enough and predicted_label in ['non_toxic', 'non-toxic']:
        emoji_pool = regular_emojis
    else:
        emoji_pool = undecided_emojis
    prediction['text'] = comment
    prediction['emoji'] = random.choice(emoji_pool)

    # Keep the result for rendering in this session.
    st.session_state.results.append(prediction)
# Start session
# Create the results history on first use so later code can append to it.
if 'results' not in st.session_state:
    st.session_state['results'] = []
# Page
# Title and introduction. The emoji in these strings had been stored as
# mis-decoded UTF-8 bytes and are restored to real characters.
# NOTE(review): the emoji after "Try it yourself!" had lost its
# distinguishing bytes; a downward-pointing hand was chosen -- confirm.
st.title('🤬 Dutch Toxic Comment Detection')
st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
# Introduction
# No placeholders in this text, so it is a plain (non-f) string.
st.markdown("""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
# The model-card links are taken from model_names_to_URLs.
st.markdown(f"""For a more comprehensive overview of the models check out their model card on 🤗 Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
""")
st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
red
tokens indicate toxicity whereas
green
tokens indicate the opposite.
Try it yourself! 👇""",
            unsafe_allow_html=True)
# Demo
# with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
# selected_model = st.selectbox('Select a model:', model_names_to_URLs.keys(),
# )#index=0, format_func=special_internal_function, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False)
# text = st.text_area(
# label='Enter the comment you want to classify below (in Dutch):')
# _, rightmost_col = st.columns([6,1])
# submitted = rightmost_col.form_submit_button("Classify",
# help="Classify comment")
# Input form: pick a sample article, optionally edit its text, then submit.
with st.form("article-inpu"):
    # TODO: should probably set a minimum length of article or something
    selected_article = st.selectbox(
        'Select an article or provide your own:',
        list_all_filenames())
    st.session_state.article_text = fetch_file_contents(selected_article)
    # st.text_area requires a label; the call previously passed only `value`.
    article_text = st.text_area(
        label='Enter the article text below:',
        value=st.session_state.article_text)
    _, rightmost_col = st.columns([6, 1])
    get_summary = rightmost_col.form_submit_button(
        "Generate summary",
        help="Generate summary for the given article text")

# Listener
if get_summary:
    if article_text:
        with st.spinner('Analysing comment...'):
            # TODO: hook up the actual processing. The previous call,
            # classify_comment(article_text, selected_model), is disabled
            # because `selected_model` is no longer defined (the model
            # selector is commented out). An explicit `pass` keeps the
            # `with` block -- and therefore the whole file -- valid Python.
            pass
    else:
        st.error('**Error**: No comment to classify. Please provide a comment.')
# Results
# Render every stored result, newest first, with a horizontal rule between
# consecutive entries.
if 'results' in st.session_state and st.session_state.results:
    newest_first = st.session_state.results[::-1]
    for position, result in enumerate(newest_first):
        if position > 0:
            st.markdown("---")

        st.markdown(f"Text:\n> {result['text']}")

        # One row of metrics: emoji reaction, predicted label, score.
        emoji_col, label_col, score_col = st.columns([1, 2, 2])
        emoji_col.metric(label='', value=f"{result['emoji']}")
        label_col.metric(label='Label', value=f"{result['label']}")
        score_col.metric(label='Score', value=f"{result['score']:.3f}")

        # The attribution markup is HTML, hence unsafe_allow_html.
        st.markdown(
            f"Token Attribution:\n{result['tokens_with_background']}",
            unsafe_allow_html=True,
        )
        st.caption(f"Model: {result['model_name']}")