File size: 11,700 Bytes
09ff543
 
 
 
6b36ae5
09ff543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b36ae5
09ff543
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
############ 1. IMPORTING LIBRARIES ############

# Import streamlit, requests for API calls, and pandas and numpy for data manipulation

import streamlit as st
import requests
import pandas as pd
import numpy as np
from streamlit_tags import st_tags  # to add labels on the fly!


############ 2. SETTING UP THE PAGE LAYOUT AND TITLE ############

# `st.set_page_config` is used to display the default layout width, the title of the app, and the emoticon in the browser tab.

# Title shared by the browser tab and the on-page heading.
APP_TITLE = "Zero-Shot Text Classifier"

# Configure the page: centered layout, tab title and tab icon.
st.set_page_config(
    layout="centered",
    page_title=APP_TITLE,
    page_icon="❄️",
)

############ CREATE THE LOGO AND HEADING ############

# Two side-by-side columns: a narrow one for the logo, a wide one for the heading.
c1, c2 = st.columns([0.32, 2])

# Left column: the logo image, rendered 85 pixels wide.
c1.image(
    "https://images.unsplash.com/photo-1508175800969-525c72a047dd?w=500&auto=format&fit=crop&q=60&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MTl8fGFmcm8lMjByb2JvdHxlbnwwfHwwfHx8MA%3D%3D",
    width=85,
)

# Right column: an empty caption followed by the app title.
c2.caption("")
c2.title(APP_TITLE)

# Seed the session-state flag on the first run only, so that later reruns
# triggered by widget interactions keep whatever value was stored.
if "valid_inputs_received" not in st.session_state:
    st.session_state["valid_inputs_received"] = False


############ SIDEBAR CONTENT ############

# Markdown blurb crediting the app author and the model, shown at the bottom of the sidebar.
about_text = """

App created by [Charly Wargnier](https://twitter.com/DataChaz) using [Streamlit](https://streamlit.io/)🎈 and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.

"""

# HuggingFace Inference API endpoint for the zero-shot model used by this app.
API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"

# Everything created inside this context manager is rendered in the sidebar.
with st.sidebar:

    st.write("")

    # Masked text input where the user pastes their HuggingFace API token.
    API_KEY = st.text_input(
        "Enter your HuggingFace API key",
        help="Once you created you HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
        type="password",
    )

    st.markdown("---")

    st.write(about_text)

# Request headers carrying the token as a bearer credential.
headers = {"Authorization": f"Bearer {API_KEY}"}


############ TABBED NAVIGATION ############

# Create the tabbed navigation via st.tabs():
# MainTab hosts the classifier itself, InfoTab shows background information.

MainTab, InfoTab = st.tabs(["Main", "Info"])

# Fill the Info tab by calling the element methods directly on the tab container.

InfoTab.subheader("What is Streamlit?")
InfoTab.markdown(
    "[Streamlit](https://streamlit.io) is a Python library that allows the creation of interactive, data-driven web applications in Python."
)

InfoTab.subheader("Resources")
InfoTab.markdown(
    """
    - [Streamlit Documentation](https://docs.streamlit.io/)
    - [Cheat sheet](https://docs.streamlit.io/library/cheatsheet)
    - [Book](https://www.amazon.com/dp/180056550X) (Getting Started with Streamlit for Data Science)
    """
)

InfoTab.subheader("Deploy")
InfoTab.markdown(
    "You can quickly deploy Streamlit apps using [Streamlit Community Cloud](https://streamlit.io/cloud) in just a few clicks."
)


with MainTab:

    # Short introduction displayed above the form.

    st.write("")
    st.markdown(
        """

    Classify keyphrases on the fly with this mighty app. No training needed!

    """
    )

    st.write("")

    # Collect the user inputs inside an `st.form`, so that the widget values
    # are sent to Streamlit in one batch when the submit button is pressed.

    with st.form(key="my_form"):

        ############ ST TAGS ############

        # The three default user intents, offered both as the initial tags
        # and as suggestions. At most three tags can be entered.
        default_intents = ("Transactional", "Informational", "Navigational")

        labels_from_st_tags = st_tags(
            value=list(default_intents),
            maxtags=3,
            suggestions=list(default_intents),
            label="",
        )

        # MAX_KEY_PHRASES caps how many phrases are classified per submission.
        # The default in this app is 50; change it to any number you like.
        MAX_KEY_PHRASES = 50

        new_line = "\n"

        # Sample phrases pre-filled in the text area; replace with your own.
        pre_defined_keyphrases = [
            "I want to buy something",
            "We have a question about a product",
            "I want a refund through the Google Play store",
            "Can I have a discount, please",
            "Can I have the link to the product page?",
        ]

        # Join the samples into a single string, one phrase per line.
        keyphrases_string = new_line.join(pre_defined_keyphrases)

        # Tooltip displayed when the user hovers over the text area.
        phrases_help = (
            "At least two keyphrases for the classifier to work, one per line, "
            f"{MAX_KEY_PHRASES}"
            " keyphrases max in 'unlocked mode'. You can tweak 'MAX_KEY_PHRASES' in the code to change this"
        )

        # Text area where users paste the phrases to classify.
        text = st.text_area(
            "Enter keyphrases to classify",
            keyphrases_string,
            height=200,
            help=phrases_help,
            key="1",
        )

        # Turn the pasted text into a clean list of phrases:
        # 1. split it on newlines (note that `text` becomes a list here),
        # 2. drop duplicates while keeping first-seen order (dict keys are ordered),
        # 3. drop empty lines.
        text = text.split("\n")
        linesList = [line for line in dict.fromkeys(text) if line]

        # Tell the user when the input is over the cap, then keep the first
        # MAX_KEY_PHRASES phrases only.
        if len(linesList) > MAX_KEY_PHRASES:
            st.info(
                f"❄️ Note that only the first {MAX_KEY_PHRASES} keyphrases will be reviewed to preserve performance. Fork the repo and tweak 'MAX_KEY_PHRASES' in the code to increase that limit."
            )

            linesList = linesList[:MAX_KEY_PHRASES]

        submit_button = st.form_submit_button(label="Submit")

    ############ CONDITIONAL STATEMENTS ############

    # Validate the submission before any API call is made.
    # Each failing case shows a warning, clears the "valid inputs" flag and
    # halts the script run with st.stop():
    #   - nothing submitted yet and no earlier valid submission: just stop,
    #   - no keyphrases left after parsing,
    #   - no labels,
    #   - only one label.

    if not submit_button and not st.session_state.valid_inputs_received:
        st.stop()

    # `text` was rebound to the result of `text.split("\n")`, which is `[""]`
    # (truthy) for an empty text area, so testing `text` can never detect an
    # empty submission. `linesList` has the empty lines removed and is the
    # list that is actually classified, so test that instead.
    elif submit_button and not linesList:
        st.warning("❄️ There is no keyphrases to classify")
        st.session_state.valid_inputs_received = False
        st.stop()

    elif submit_button and not labels_from_st_tags:
        st.warning("❄️ You have not added any labels, please add some! ")
        st.session_state.valid_inputs_received = False
        st.stop()

    elif submit_button and len(labels_from_st_tags) == 1:
        st.warning("❄️ Please make sure to add at least two labels for classification")
        st.session_state.valid_inputs_received = False
        st.stop()

    elif submit_button or st.session_state.valid_inputs_received:

        if submit_button:

            # Remember that a valid submission was received, so that reruns
            # not caused by the submit button still reach the results below.

            st.session_state.valid_inputs_received = True

        ############ MAKING THE API CALL ############

        # Helper that performs one call to the HuggingFace Inference API.

        def query(payload, timeout=(10, 120)):
            """Send `payload` to the inference endpoint and return the decoded JSON.

            An HTTP POST request is sent to `API_URL` with the module-level
            `headers`, and `payload` is serialised as the JSON request body.

            Args:
                payload: The data to send to HuggingFace for this request.
                timeout: `(connect, read)` timeout in seconds handed to
                    `requests.post`. Without a timeout, `requests` waits
                    indefinitely on a stalled connection, which would freeze
                    the app run. The read timeout is generous because the
                    caller asks the API to wait for the model to load.

            Returns:
                The response body parsed from JSON.

            Raises:
                requests.exceptions.Timeout: If the server does not answer
                    within `timeout`.
            """
            response = requests.post(
                API_URL, headers=headers, json=payload, timeout=timeout
            )
            return response.json()

        # Collect the API output for every keyphrase, one API call per phrase.

        list_for_api_output = []

        # The payload is composed of:
        #   1. the keyphrase
        #   2. the labels
        #   3. the 'wait_for_model' option set to True, to avoid timeouts!

        for row in linesList:
            api_json_output = query(
                {
                    "inputs": row,
                    "parameters": {"candidate_labels": labels_from_st_tags},
                    "options": {"wait_for_model": True},
                }
            )

            # On failure (for example an invalid token) the API answers with a
            # JSON object carrying an "error" key instead of labels and scores.
            # Surface the message and halt the run, rather than failing later
            # with a KeyError on the missing "scores" column.
            if isinstance(api_json_output, dict) and "error" in api_json_output:
                st.error(
                    f"❄️ The HuggingFace API returned an error: {api_json_output['error']}"
                )
                st.stop()

            list_for_api_output.append(api_json_output)

        # With no phrases there is nothing to tabulate; stop here instead of
        # continuing without any results.
        if not list_for_api_output:
            st.warning("❄️ There is no keyphrases to classify")
            st.stop()

        # Convert the collected results to a dataframe once, after the loop,
        # instead of rebuilding it on every iteration.
        df = pd.DataFrame.from_dict(list_for_api_output)

        # Success banner. The check-mark emoji had been stored as mis-decoded
        # bytes ("βœ…"); restore the intended character.
        st.success("✅ Done!")

        # Heading for the results table, framed by two empty captions.
        st.caption("")
        st.markdown("### Check the results!")
        st.caption("")

        # st.write(df)

        ############ DATA WRANGLING ON THE RESULTS ############
        # Various data wrangling to get the data in the right format!

        # Format every score as a percentage string with two decimals,
        # keeping one list of formatted scores per keyphrase.
        percent_scores = [
            [f"{score:.2%}" for score in score_list] for score_list in df["scores"]
        ]
        df["classification scores"] = percent_scores

        # The API calls the input text 'sequence'; present it as 'keyphrase'.
        df = df.rename(columns={"sequence": "keyphrase"})

        # The API returns a list of all labels sorted by score. We only want the top label,
        # so keep the first element of the 'labels' and 'classification scores' lists.
        df["label"] = df["labels"].str[0]
        df["accuracy"] = df["classification scores"].str[0]

        # The list-valued columns are no longer needed once the top entry is extracted.
        df = df.drop(columns=["scores", "labels", "classification scores"])

        # Number the rows from 1 instead of 0.
        df.index = np.arange(1, len(df) + 1)

        # Display the dataframe
        st.write(df)

        cs, c1 = st.columns([2, 2])

        # The download button lives in the left-hand column.

        with cs:

            # `st.experimental_memo` is deprecated in favour of `st.cache_data`.
            # Prefer the current API and fall back to the old name only on
            # Streamlit versions that do not provide it yet.
            cache_decorator = getattr(st, "cache_data", None) or st.experimental_memo

            @cache_decorator
            def convert_df(df):
                """Serialise `df` to CSV and return it as UTF-8 encoded bytes.

                The result is cached so the conversion is not recomputed on
                every rerun.
                """
                return df.to_csv().encode("utf-8")

            csv = convert_df(df)

            st.caption("")

            st.download_button(
                label="Download results",
                data=csv,
                file_name="classification_results.csv",
                mime="text/csv",
            )