File size: 2,067 Bytes
a452638
 
dbf4eb2
7c7d7c6
57fe04c
 
 
dbf4eb2
7c7d7c6
 
 
 
 
 
 
dbf4eb2
57fe04c
 
 
 
 
 
 
 
a452638
57fe04c
 
7c7d7c6
57fe04c
 
 
aa13915
 
dbf4eb2
 
7c7d7c6
 
 
dbf4eb2
a452638
 
 
dbf4eb2
 
76b4c44
7768d4a
76b4c44
dbf4eb2
 
c0c5b72
 
f6d9cae
8e532c0
dbf4eb2
c0c5b72
 
dbf4eb2
c0c5b72
 
 
dbf4eb2
c0c5b72
dbf4eb2
c0c5b72
38f2cf5
c0c5b72
 
a452638
c0c5b72
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from io import BytesIO

import pandas as pd
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline

# Page-level Streamlit settings: sets the page title and a wide layout; the
# icon and menu items are passed as None and the sidebar state as "auto".
st.set_page_config(
    page_title="Zero-shot classification from tabular data",
    page_icon=None,
    layout="wide",
    initial_sidebar_state="auto",
    menu_items=None,
)


@st.cache(
    allow_output_mutation=True,
    show_spinner=False,
    # Each of these types is given a hash function that always yields None.
    # NOTE(review): presumably this keeps Streamlit from trying to hash model
    # weights and tokenizer internals -- confirm against the st.cache docs.
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _unused: None,
        tokenizers.Tokenizer: lambda _unused: None,
        tokenizers.AddedToken: lambda _unused: None,
    },
)
def load_classifier() -> Pipeline:
    """Return the zero-shot classification pipeline used by the app.

    The pipeline is created for the ``zero-shot-classification`` task with the
    ``facebook/bart-large-mnli`` model. The function is wrapped in
    ``st.cache`` with output mutation allowed and the built-in spinner
    disabled.

    Returns:
        The ``transformers`` pipeline object.
    """
    return pipeline(
        task="zero-shot-classification",
        model="facebook/bart-large-mnli",
    )


# Obtain the classifier up front, showing a spinner while load_classifier()
# runs.
with st.spinner(text="Setting stuff up related to the inference engine..."):
    classifier = load_classifier()

# Page heading and a one-line description of the app.
st.title("Zero-shot classification from tabular data")
st.text(
    "Upload an Excel table and perform zero-shot classification on a set of custom labels"
)

# User inputs: the uploaded file and the raw, comma-separated label string.
data = st.file_uploader(
    "Upload Excel file (it should contain a column named `text` in its header):"
)
labels = st.text_input("Enter comma-separated labels:")

# Upper bound on the number of rows that get classified: only the first N rows
# of the uploaded table are used, for faster inference.
N = 10000

# Runs when the user presses the button: classify every row of the uploaded
# table, display the result and offer it as an Excel download.
if st.button("Calculate labels"):
    # One user-facing message is shared by every failure mode (no file, no
    # labels, unreadable workbook, missing `text` column, inference failure).
    error_message = (
        "Something went wrong. Make sure you upload an Excel file containing a column named `text` and a set of comma-separated labels is provided"
    )

    # Parse the labels once: trim surrounding whitespace and drop empty
    # entries so that input such as "a, b," yields ["a", "b"].
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]

    if data is None or not labels_list:
        # Nothing to classify (or nothing to classify against): report it
        # directly instead of relying on a downstream exception.
        st.error(error_message)
    else:
        try:
            table = pd.read_excel(data)
            # Only the first N rows are classified; the index is reset so the
            # exported sheet is numbered from 0.
            table = table.head(N).reset_index(drop=True)

            # Raises KeyError when the `text` column is missing (handled below).
            texts = table["text"]
            total = len(texts)

            prog_bar = st.progress(0)
            preds = []

            for done, text in enumerate(texts, start=1):
                # Keep the first entry of the returned "labels" list.
                # NOTE(review): per the transformers docs this list is sorted
                # by score, so the first entry is the best label -- confirm.
                preds.append(classifier(text, labels_list)["labels"][0])
                prog_bar.progress(done / total)

            table["label"] = preds
            result = table[["text", "label"]]

            st.table(result)

            # Serialise the result to an in-memory workbook for the download.
            buf = BytesIO()
            result.to_excel(buf)

            st.download_button(
                label="Download table", data=buf.getvalue(), file_name="output.xlsx"
            )

        except Exception:
            # Deliberately not a bare `except:`, which would also trap
            # BaseException subclasses such as KeyboardInterrupt and
            # SystemExit and report them as a user input problem.
            st.error(error_message)