MatthiasCami8 commited on
Commit
76254b2
Β·
unverified Β·
1 Parent(s): 33d91ae

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import streamlit as st
4
+ from bs4 import BeautifulSoup
5
+
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from transformers import pipeline
8
+ from transformers_interpret import SequenceClassificationExplainer
9
+
10
+ # Map model names to URLs
11
+ model_names_to_URLs = {
12
+ 'ml6team/distilbert-base-dutch-cased-toxic-comments':
13
+ 'https://huggingface.co/ml6team/distilbert-base-dutch-cased-toxic-comments',
14
+ 'ml6team/robbert-dutch-base-toxic-comments':
15
+ 'https://huggingface.co/ml6team/robbert-dutch-base-toxic-comments',
16
+ }
17
+
18
+ about_page_markdown = f"""# 🀬 Dutch Toxic Comment Detection Space
19
+
20
+ Made by [ML6](https://ml6.eu/).
21
+
22
+ Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
23
+ """
24
+
25
+ regular_emojis = [
26
+ '😐', 'πŸ™‚', 'πŸ‘Ά', 'πŸ˜‡',
27
+ ]
28
+ undecided_emojis = [
29
+ '🀨', '🧐', 'πŸ₯Έ', 'πŸ₯΄', '🀷',
30
+ ]
31
+ potty_mouth_emojis = [
32
+ '🀐', 'πŸ‘Ώ', '😑', '🀬', '☠️', '☣️', '☒️',
33
+ ]
34
+
35
+ # Page setup
36
+ st.set_page_config(
37
+ page_title="Toxic Comment Detection Space",
38
+ page_icon="🀬",
39
+ layout="centered",
40
+ initial_sidebar_state="auto",
41
+ menu_items={
42
+ 'Get help': None,
43
+ 'Report a bug': None,
44
+ 'About': about_page_markdown,
45
+ }
46
+ )
47
+
48
+ # Model setup
49
+ @st.cache(allow_output_mutation=True,
50
+ suppress_st_warning=True,
51
+ show_spinner=False)
52
+ def load_pipeline(model_name):
53
+ with st.spinner('Loading model (this might take a while)...'):
54
+ toxicity_pipeline = pipeline(
55
+ 'text-classification',
56
+ model=model_name,
57
+ tokenizer=model_name)
58
+ cls_explainer = SequenceClassificationExplainer(
59
+ toxicity_pipeline.model,
60
+ toxicity_pipeline.tokenizer)
61
+ return toxicity_pipeline, cls_explainer
62
+
63
+
64
+ # Auxiliary functions
65
+ def format_explainer_html(html_string):
66
+ """Extract tokens with attribution-based background color."""
67
+ inside_token_prefix = '##'
68
+ soup = BeautifulSoup(html_string, 'html.parser')
69
+ p = soup.new_tag('p',
70
+ attrs={'style': 'color: black; background-color: white;'})
71
+ # Select token elements and remove model specific tokens
72
+ current_word = None
73
+ for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
74
+ text = token.font.text.strip()
75
+ if text.startswith(inside_token_prefix):
76
+ text = text[len(inside_token_prefix):]
77
+ else:
78
+ # Create a new span for each word (sequence of sub-tokens)
79
+ if current_word is not None:
80
+ p.append(current_word)
81
+ p.append(' ')
82
+ current_word = soup.new_tag('span')
83
+ token.string = text
84
+ token.attrs['style'] = f"{token.attrs['style']}; padding: 0.2em 0em;"
85
+ current_word.append(token)
86
+
87
+ # Add last word
88
+ p.append(current_word)
89
+
90
+ # Add left and right-padding to each word
91
+ for span in p.find_all('span'):
92
+ span.find_all('mark')[0].attrs['style'] = (
93
+ f"{span.find_all('mark')[0].attrs['style']}; padding-left: 0.2em;")
94
+ span.find_all('mark')[-1].attrs['style'] = (
95
+ f"{span.find_all('mark')[-1].attrs['style']}; padding-right: 0.2em;")
96
+
97
+ return p
98
+
99
+
100
+ def classify_comment(comment, selected_model):
101
+ """Classify the given comment and augment with additional information."""
102
+ toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
103
+ result = toxicity_pipeline(comment)[0]
104
+ result['model_name'] = selected_model
105
+
106
+ # Add explanation
107
+ result['word_attribution'] = cls_explainer(comment, class_name="non-toxic")
108
+ result['visualitsation_html'] = cls_explainer.visualize()._repr_html_()
109
+ result['tokens_with_background'] = format_explainer_html(
110
+ result['visualitsation_html'])
111
+
112
+ # Choose emoji reaction
113
+ label, score = result['label'], result['score']
114
+ if label == 'toxic' and score > 0.1:
115
+ emoji = random.choice(potty_mouth_emojis)
116
+ elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
117
+ emoji = random.choice(regular_emojis)
118
+ else:
119
+ emoji = random.choice(undecided_emojis)
120
+ result.update({'text': comment, 'emoji': emoji})
121
+
122
+ # Add result to session
123
+ st.session_state.results.append(result)
124
+
125
+
126
+ # Start session
127
+ if 'results' not in st.session_state:
128
+ st.session_state.results = []
129
+
130
+ # Page
131
+ st.title('🀬 Dutch Toxic Comment Detection')
132
+ st.markdown("""This demo showcases two Dutch toxic comment detection models.""")
133
+
134
+ # Introduction
135
+ st.markdown(f"""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
136
+ The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
137
+ st.markdown(f"""For a more comprehensive overview of the models check out their model card on πŸ€— Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
138
+ """)
139
+ st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
140
+ <font color="black">
141
+ <span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span>
142
+ </font>
143
+ tokens indicate toxicity whereas
144
+ <font color="black">
145
+ <span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span>
146
+ </font> tokens indicate the opposite.
147
+
148
+ Try it yourself! πŸ‘‡""",
149
+ unsafe_allow_html=True)
150
+
151
+
152
+ # Demo
153
+ with st.form("dutch-toxic-comment-detection-input", clear_on_submit=True):
154
+ selected_model = st.selectbox('Select a model:', model_names_to_URLs.keys(),
155
+ )#index=0, format_func=special_internal_function, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False)
156
+ text = st.text_area(
157
+ label='Enter the comment you want to classify below (in Dutch):')
158
+ _, rightmost_col = st.columns([6,1])
159
+ submitted = rightmost_col.form_submit_button("Classify",
160
+ help="Classify comment")
161
+
162
+ # Listener
163
+ if submitted:
164
+ if text:
165
+ with st.spinner('Analysing comment...'):
166
+ classify_comment(text, selected_model)
167
+ else:
168
+ st.error('**Error**: No comment to classify. Please provide a comment.')
169
+
170
+ # Results
171
+ if 'results' in st.session_state and st.session_state.results:
172
+ first = True
173
+ for result in st.session_state.results[::-1]:
174
+ if not first:
175
+ st.markdown("---")
176
+ st.markdown(f"Text:\n> {result['text']}")
177
+ col_1, col_2, col_3 = st.columns([1,2,2])
178
+ col_1.metric(label='', value=f"{result['emoji']}")
179
+ col_2.metric(label='Label', value=f"{result['label']}")
180
+ col_3.metric(label='Score', value=f"{result['score']:.3f}")
181
+ st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
182
+ unsafe_allow_html=True)
183
+ st.caption(f"Model: {result['model_name']}")
184
+ first = False
185
+