Tamara Adokeme commited on
Commit
09ff543
1 Parent(s): 6b36ae5

Initial classifier config

Browse files
Files changed (1) hide show
  1. app.py +346 -2
app.py CHANGED
@@ -1,4 +1,348 @@
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
1
+ ############ 1. IMPORTING LIBRARIES ############
2
+
3
+ # Import streamlit, requests for API calls, and pandas and numpy for data manipulation
4
+
5
  import streamlit as st
6
+ import requests
7
+ import pandas as pd
8
+ import numpy as np
9
+ from streamlit_tags import st_tags # to add labels on the fly!
10
+
11
+
12
+ ############ 2. SETTING UP THE PAGE LAYOUT AND TITLE ############
13
+
14
+ # `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.
15
+
16
+ st.set_page_config(
17
+ layout="centered", page_title="Zero-Shot Text Classifier", page_icon="❄️"
18
+ )
19
+
20
+ ############ CREATE THE LOGO AND HEADING ############
21
+
22
+ # We create a set of columns to display the logo and the heading next to each other.
23
+
24
+
25
+ c1, c2 = st.columns([0.32, 2])
26
+
27
+ # The snowflake logo will be displayed in the first column, on the left.
28
+
29
+ with c1:
30
+
31
+ st.image(
32
+ "https://images.unsplash.com/photo-1508175800969-525c72a047dd?w=500&auto=format&fit=crop&q=60&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MTl8fGFmcm8lMjByb2JvdHxlbnwwfHwwfHx8MA%3D%3D",
33
+ width=85,
34
+ )
35
+
36
+
37
+ # The heading will be on the right.
38
+
39
+ with c2:
40
+
41
+ st.caption("")
42
+ st.title("Zero-Shot Text Classifier")
43
+
44
+
45
+ # We need to set up session state via st.session_state so that app interactions don't reset the app.
46
+
47
+ if not "valid_inputs_received" in st.session_state:
48
+ st.session_state["valid_inputs_received"] = False
49
+
50
+
51
+ ############ SIDEBAR CONTENT ############
52
+
53
+ st.sidebar.write("")
54
+
55
+ # For elements to be displayed in the sidebar, we need to add the sidebar element in the widget.
56
+
57
+ # We create a text input field for users to enter their API key.
58
+
59
+ API_KEY = st.sidebar.text_input(
60
+ "Enter your HuggingFace API key",
61
+ help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
62
+ type="password",
63
+ )
64
+
65
+ # Adding the HuggingFace API inference URL.
66
+ API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"
67
+
68
+ # Now, let's create a Python dictionary to store the API headers.
69
+ headers = {"Authorization": f"Bearer {API_KEY}"}
70
+
71
+
72
+ st.sidebar.markdown("---")
73
+
74
+
75
+ # Let's add some info about the app to the sidebar.
76
+
77
+ st.sidebar.write(
78
+ """
79
+
80
+ App created by [Charly Wargnier](https://twitter.com/DataChaz) using [Streamlit](https://streamlit.io/)🎈 and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.
81
+
82
+ """
83
+ )
84
+
85
+
86
+ ############ TABBED NAVIGATION ############
87
+
88
+ # First, we're going to create a tabbed navigation for the app via st.tabs()
89
+ # tabInfo displays info about the app.
90
+ # tabMain displays the main app.
91
+
92
+ MainTab, InfoTab = st.tabs(["Main", "Info"])
93
+
94
+ with InfoTab:
95
+
96
+ st.subheader("What is Streamlit?")
97
+ st.markdown(
98
+ "[Streamlit](https://streamlit.io) is a Python library that allows the creation of interactive, data-driven web applications in Python."
99
+ )
100
+
101
+ st.subheader("Resources")
102
+ st.markdown(
103
+ """
104
+ - [Streamlit Documentation](https://docs.streamlit.io/)
105
+ - [Cheat sheet](https://docs.streamlit.io/library/cheatsheet)
106
+ - [Book](https://www.amazon.com/dp/180056550X) (Getting Started with Streamlit for Data Science)
107
+ """
108
+ )
109
+
110
+ st.subheader("Deploy")
111
+ st.markdown(
112
+ "You can quickly deploy Streamlit apps using [Streamlit Community Cloud](https://streamlit.io/cloud) in just a few clicks."
113
+ )
114
+
115
+
116
+ with MainTab:
117
+
118
+ # Then, we create a intro text for the app, which we wrap in a st.markdown() widget.
119
+
120
+ st.write("")
121
+ st.markdown(
122
+ """
123
+
124
+ Classify keyphrases on the fly with this mighty app. No training needed!
125
+
126
+ """
127
+ )
128
+
129
+ st.write("")
130
+
131
+ # Now, we create a form via `st.form` to collect the user inputs.
132
+
133
+ # All widget values will be sent to Streamlit in batch.
134
+ # It makes the app faster!
135
+
136
+ with st.form(key="my_form"):
137
+
138
+ ############ ST TAGS ############
139
+
140
+ # We initialize the st_tags component with default "labels"
141
+
142
+ # Here, we want to classify the text into one of the following user intents:
143
+ # Transactional
144
+ # Informational
145
+ # Navigational
146
+
147
+ labels_from_st_tags = st_tags(
148
+ value=["Transactional", "Informational", "Navigational"],
149
+ maxtags=3,
150
+ suggestions=["Transactional", "Informational", "Navigational"],
151
+ label="",
152
+ )
153
+
154
+ # The block of code below is to display some text samples to classify.
155
+ # This can of course be replaced with your own text samples.
156
+
157
+ # MAX_KEY_PHRASES is a variable that controls the number of phrases that can be pasted:
158
+ # The default in this app is 50 phrases. This can be changed to any number you like.
159
+
160
+ MAX_KEY_PHRASES = 50
161
+
162
+ new_line = "\n"
163
+
164
+ pre_defined_keyphrases = [
165
+ "I want to buy something",
166
+ "We have a question about a product",
167
+ "I want a refund through the Google Play store",
168
+ "Can I have a discount, please",
169
+ "Can I have the link to the product page?",
170
+ ]
171
+
172
+ # Python list comprehension to create a string from the list of keyphrases.
173
+ keyphrases_string = f"{new_line.join(map(str, pre_defined_keyphrases))}"
174
+
175
+ # The block of code below displays a text area
176
+ # So users can paste their phrases to classify
177
+
178
+ text = st.text_area(
179
+ # Instructions
180
+ "Enter keyphrases to classify",
181
+ # 'sample' variable that contains our keyphrases.
182
+ keyphrases_string,
183
+ # The height
184
+ height=200,
185
+ # The tooltip displayed when the user hovers over the text area.
186
+ help="At least two keyphrases for the classifier to work, one per line, "
187
+ + str(MAX_KEY_PHRASES)
188
+ + " keyphrases max in 'unlocked mode'. You can tweak 'MAX_KEY_PHRASES' in the code to change this",
189
+ key="1",
190
+ )
191
+
192
+ # The block of code below:
193
+
194
+ # 1. Converts the data st.text_area into a Python list.
195
+ # 2. It also removes duplicates and empty lines.
196
+ # 3. Raises an error if the user has entered more lines than in MAX_KEY_PHRASES.
197
+
198
+ text = text.split("\n") # Converts the pasted text to a Python list
199
+ linesList = [] # Creates an empty list
200
+ for x in text:
201
+ linesList.append(x) # Adds each line to the list
202
+ linesList = list(dict.fromkeys(linesList)) # Removes dupes
203
+ linesList = list(filter(None, linesList)) # Removes empty lines
204
+
205
+ if len(linesList) > MAX_KEY_PHRASES:
206
+ st.info(
207
+ f"❄️ Note that only the first "
208
+ + str(MAX_KEY_PHRASES)
209
+ + " keyphrases will be reviewed to preserve performance. Fork the repo and tweak 'MAX_KEY_PHRASES' in the code to increase that limit."
210
+ )
211
+
212
+ linesList = linesList[:MAX_KEY_PHRASES]
213
+
214
+ submit_button = st.form_submit_button(label="Submit")
215
+
216
+ ############ CONDITIONAL STATEMENTS ############
217
+
218
+ # Now, let us add conditional statements to check if users have entered valid inputs.
219
+ # E.g. If the user has pressed the 'submit button without text, without labels, and with only one label etc.
220
+ # The app will display a warning message.
221
+
222
+ if not submit_button and not st.session_state.valid_inputs_received:
223
+ st.stop()
224
+
225
+ elif submit_button and not text:
226
+ st.warning("❄️ There is no keyphrases to classify")
227
+ st.session_state.valid_inputs_received = False
228
+ st.stop()
229
+
230
+ elif submit_button and not labels_from_st_tags:
231
+ st.warning("❄️ You have not added any labels, please add some! ")
232
+ st.session_state.valid_inputs_received = False
233
+ st.stop()
234
+
235
+ elif submit_button and len(labels_from_st_tags) == 1:
236
+ st.warning("❄️ Please make sure to add at least two labels for classification")
237
+ st.session_state.valid_inputs_received = False
238
+ st.stop()
239
+
240
+ elif submit_button or st.session_state.valid_inputs_received:
241
+
242
+ if submit_button:
243
+
244
+ # The block of code below if for our session state.
245
+ # This is used to store the user's inputs so that they can be used later in the app.
246
+
247
+ st.session_state.valid_inputs_received = True
248
+
249
+ ############ MAKING THE API CALL ############
250
+
251
+ # First, we create a Python function to construct the API call.
252
+
253
+ def query(payload):
254
+ response = requests.post(API_URL, headers=headers, json=payload)
255
+ return response.json()
256
+
257
+ # The function will send an HTTP POST request to the API endpoint.
258
+ # This function has one argument: the payload
259
+ # The payload is the data we want to send to HugggingFace when we make an API request
260
+
261
+ # We create a list to store the outputs of the API call
262
+
263
+ list_for_api_output = []
264
+
265
+ # We create a 'for loop' that iterates through each keyphrase
266
+ # An API call will be made every time, for each keyphrase
267
+
268
+ # The payload is composed of:
269
+ # 1. the keyphrase
270
+ # 2. the labels
271
+ # 3. the 'wait_for_model' parameter set to "True", to avoid timeouts!
272
+
273
+ for row in linesList:
274
+ api_json_output = query(
275
+ {
276
+ "inputs": row,
277
+ "parameters": {"candidate_labels": labels_from_st_tags},
278
+ "options": {"wait_for_model": True},
279
+ }
280
+ )
281
+
282
+ # Let's have a look at the output of the API call
283
+ # st.write(api_json_output)
284
+
285
+ # All the results are appended to the empty list we created earlier
286
+ list_for_api_output.append(api_json_output)
287
+
288
+ # then we'll convert the list to a dataframe
289
+ df = pd.DataFrame.from_dict(list_for_api_output)
290
+
291
+ st.success("✅ Done!")
292
+
293
+ st.caption("")
294
+ st.markdown("### Check the results!")
295
+ st.caption("")
296
+
297
+ # st.write(df)
298
+
299
+ ############ DATA WRANGLING ON THE RESULTS ############
300
+ # Various data wrangling to get the data in the right format!
301
+
302
+ # List comprehension to convert the score from decimals to percentages
303
+ f = [[f"{x:.2%}" for x in row] for row in df["scores"]]
304
+
305
+ # Join the classification scores to the dataframe
306
+ df["classification scores"] = f
307
+
308
+ # Rename the column 'sequence' to 'keyphrase'
309
+ df.rename(columns={"sequence": "keyphrase"}, inplace=True)
310
+
311
+ # The API returns a list of all labels sorted by score. We only want the top label.
312
+
313
+ # For that, we need to select the first element in the 'labels' and 'classification scores' lists
314
+ df["label"] = df["labels"].str[0]
315
+ df["accuracy"] = df["classification scores"].str[0]
316
+
317
+ # Drop the columns we don't need
318
+ df.drop(["scores", "labels", "classification scores"], inplace=True, axis=1)
319
+
320
+ # st.write(df)
321
+
322
+ # We need to change the index. Index starts at 0, so we make it start at 1
323
+ df.index = np.arange(1, len(df) + 1)
324
+
325
+ # Display the dataframe
326
+ st.write(df)
327
+
328
+ cs, c1 = st.columns([2, 2])
329
+
330
+ # The code below is for the download button
331
+ # Cache the conversion to prevent computation on every rerun
332
+
333
+ with cs:
334
+
335
+ @st.experimental_memo
336
+ def convert_df(df):
337
+ return df.to_csv().encode("utf-8")
338
+
339
+ csv = convert_df(df)
340
+
341
+ st.caption("")
342
 
343
+ st.download_button(
344
+ label="Download results",
345
+ data=csv,
346
+ file_name="classification_results.csv",
347
+ mime="text/csv",
348
+ )