MatthiasCami8
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from bs4 import BeautifulSoup
|
5 |
+
|
6 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
7 |
+
from transformers import pipeline
|
8 |
+
from transformers_interpret import SequenceClassificationExplainer
|
9 |
+
|
10 |
+
# Map model names to URLs
|
11 |
+
model_names_to_URLs = {
|
12 |
+
'ml6team/distilbert-base-dutch-cased-toxic-comments':
|
13 |
+
'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
|
14 |
+
'ml6team/robbert-dutch-base-toxic-comments':
|
15 |
+
'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
|
16 |
+
}
|
17 |
+
|
18 |
+
about_page_markdown = f"""# π€¬ Dutch Toxic Comment Detection Space
|
19 |
+
|
20 |
+
Made by [ML6](https://ml6.eu/).
|
21 |
+
|
22 |
+
Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
|
23 |
+
"""
|
24 |
+
|
25 |
+
regular_emojis = [
|
26 |
+
'π', 'π', 'πΆ', 'π',
|
27 |
+
]
|
28 |
+
undecided_emojis = [
|
29 |
+
'π€¨', 'π§', 'π₯Έ', 'π₯΄', 'π€·',
|
30 |
+
]
|
31 |
+
potty_mouth_emojis = [
|
32 |
+
'π€', 'πΏ', 'π‘', 'π€¬', 'β οΈ', 'β£οΈ', 'β’οΈ',
|
33 |
+
]
|
34 |
+
|
35 |
+
# Page setup
|
36 |
+
st.set_page_config(
|
37 |
+
page_title="Toxic Comment Detection Space",
|
38 |
+
page_icon="π€¬",
|
39 |
+
layout="centered",
|
40 |
+
initial_sidebar_state="auto",
|
41 |
+
menu_items={
|
42 |
+
'Get help': None,
|
43 |
+
'Report a bug': None,
|
44 |
+
'About': about_page_markdown,
|
45 |
+
}
|
46 |
+
)
|
47 |
+
|
48 |
+
# Model setup
|
49 |
+
@st.cache(allow_output_mutation=True,
|
50 |
+
suppress_st_warning=True,
|
51 |
+
show_spinner=False)
|
52 |
+
def load_pipeline(model_name):
|
53 |
+
with st.spinner('Loading model (this might take a while)...'):
|
54 |
+
toxicity_pipeline = pipeline(
|
55 |
+
'text-classification',
|
56 |
+
model=model_name,
|
57 |
+
tokenizer=model_name)
|
58 |
+
cls_explainer = SequenceClassificationExplainer(
|
59 |
+
toxicity_pipeline.model,
|
60 |
+
toxicity_pipeline.tokenizer)
|
61 |
+
return toxicity_pipeline, cls_explainer
|
62 |
+
|
63 |
+
|
64 |
+
# Auxiliary functions
|
65 |
+
def format_explainer_html(html_string):
|
66 |
+
"""Extract tokens with attribution-based background color."""
|
67 |
+
inside_token_prefix = '##'
|
68 |
+
soup = BeautifulSoup(html_string, 'html.parser')
|
69 |
+
p = soup.new_tag('p',
|
70 |
+
attrs={'style': 'color: black; background-color: white;'})
|
71 |
+
# Select token elements and remove model specific tokens
|
72 |
+
current_word = None
|
73 |
+
for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
|
74 |
+
text = token.font.text.strip()
|
75 |
+
if text.startswith(inside_token_prefix):
|
76 |
+
text = text[len(inside_token_prefix):]
|
77 |
+
else:
|
78 |
+
# Create a new span for each word (sequence of sub-tokens)
|
79 |
+
if current_word is not None:
|
80 |
+
p.append(current_word)
|
81 |
+
p.append(' ')
|
82 |
+
current_word = soup.new_tag('span')
|
83 |
+
token.string = text
|
84 |
+
token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
|
85 |
+
current_word.append(token)
|
86 |
+
|
87 |
+
# Add last word
|
88 |
+
p.append(current_word)
|
89 |
+
|
90 |
+
# Add left and right-padding to each word
|
91 |
+
for span in p.find_all('span'):
|
92 |
+
span.find_all('mark')[0].attrs['style'] = (
|
93 |
+
f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;")
|
94 |
+
span.find_all('mark')[-1].attrs['style'] = (
|
95 |
+
f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;")
|
96 |
+
|
97 |
+
return p
|
98 |
+
|
99 |
+
|
100 |
+
def classify_comment(comment, selected_model):
|
101 |
+
"""Classify the given comment and augment with additional information."""
|
102 |
+
toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
|
103 |
+
result = toxicity_pipeline(comment)[0]
|
104 |
+
result['model_name'] = selected_model
|
105 |
+
|
106 |
+
# Add explanation
|
107 |
+
result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
|
108 |
+
result['visualitsation_html'] = cls_explainer.visualize()._repr_html_()
|
109 |
+
result['tokens_with_background'] = format_explainer_html(
|
110 |
+
result['visualitsation_html'])
|
111 |
+
|
112 |
+
# Choose emoji reaction
|
113 |
+
label, score = result['label'], result['score']
|
114 |
+
if label == 'toxic' and score > 0.1:
|
115 |
+
emoji = random.choice(potty_mouth_emojis)
|
116 |
+
elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
|
117 |
+
emoji = random.choice(regular_emojis)
|
118 |
+
else:
|
119 |
+
emoji = random.choice(undecided_emojis)
|
120 |
+
result.update({'text': comment, 'emoji': emoji})
|
121 |
+
|
122 |
+
# Add result to session
|
123 |
+
st.session_state.results.append(result)
|
124 |
+
|
125 |
+
|
126 |
+
# Start session
|
127 |
+
if 'results' not in st.session_state:
|
128 |
+
st.session_state.results = []
|
129 |
+
|
130 |
+
# Page
|
131 |
+
st.title('π€¬ Dutch Toxic Comment Detection')
|
132 |
+
st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
|
133 |
+
|
134 |
+
# Introduction
|
135 |
+
st.markdown(f"""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
|
136 |
+
The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
|
137 |
+
st.markdown(f"""For a more comprehensive overview of the models check out their model card on π€ Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
|
138 |
+
""")
|
139 |
+
st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
|
140 |
+
<font color="black">
|
141 |
+
<span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span>
|
142 |
+
</font>
|
143 |
+
tokens indicate toxicity whereas
|
144 |
+
<font color="black">
|
145 |
+
<span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span>
|
146 |
+
</font> tokens indicate the opposite.
|
147 |
+
|
148 |
+
Try it yourself! π""",
|
149 |
+
unsafe_allow_html=True)
|
150 |
+
|
151 |
+
|
152 |
+
# Demo
|
153 |
+
with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
|
154 |
+
selected_model = st.selectbox('Select a model:', model_names_to_URLs.keys(),
|
155 |
+
)#index=0, format_func=special_internal_function, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False)
|
156 |
+
text = st.text_area(
|
157 |
+
label='Enter the comment you want to classify below (in Dutch):')
|
158 |
+
_, rightmost_col = st.columns([6,1])
|
159 |
+
submitted = rightmost_col.form_submit_button("Classify",
|
160 |
+
help="Classify comment")
|
161 |
+
|
162 |
+
# Listener
|
163 |
+
if submitted:
|
164 |
+
if text:
|
165 |
+
with st.spinner('Analysing comment...'):
|
166 |
+
classify_comment(text, selected_model)
|
167 |
+
else:
|
168 |
+
st.error('**Error**: No comment to classify. Please provide a comment.')
|
169 |
+
|
170 |
+
# Results
|
171 |
+
if 'results' in st.session_state and st.session_state.results:
|
172 |
+
first = True
|
173 |
+
for result in st.session_state.results[::-1]:
|
174 |
+
if not first:
|
175 |
+
st.markdown("---")
|
176 |
+
st.markdown(f"Text:\n> {result['text']}")
|
177 |
+
col_1, col_2, col_3 = st.columns([1,2,2])
|
178 |
+
col_1.metric(label='', value=f"{result['emoji']}")
|
179 |
+
col_2.metric(label='Label', value=f"{result['label']}")
|
180 |
+
col_3.metric(label='Score', value=f"{result['score']:.3f}")
|
181 |
+
st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
|
182 |
+
unsafe_allow_html=True)
|
183 |
+
st.caption(f"Model: {result['model_name']}")
|
184 |
+
first = False
|
185 |
+
|