File size: 6,487 Bytes
a922691
e690399
a922691
 
 
 
 
 
 
e690399
 
a922691
 
 
 
85dd546
 
 
a922691
 
 
e690399
 
 
 
 
 
 
a922691
 
a8d963e
 
 
a922691
 
e690399
 
 
 
 
 
 
 
 
a922691
 
 
e690399
 
 
a922691
 
e690399
a922691
a8d963e
a922691
e690399
a922691
 
 
 
a8d963e
a922691
e690399
a922691
a8d963e
85dd546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e690399
 
 
 
 
a922691
 
 
 
 
 
 
e690399
 
a922691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e690399
 
 
a922691
 
 
 
e690399
 
 
 
a922691
 
 
 
 
 
 
e690399
a922691
 
85dd546
a922691
85dd546
a922691
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import numpy as np
import contextlib
import plotly.express as px
import pandas as pd
from PIL import Image
import datetime
import os
import psutil

# Append a timestamp to the local hit log on every page load (crude usage counter).
with open("hit_log.txt", "a") as log_file:
    log_file.write(f"{datetime.datetime.now()}\n")


# Cap on how many prediction bars the results chart displays.
MAX_GRAPH_ROWS = 10

# Sidebar description shown for each selectable model (keys double as the
# selectbox display names).
MODEL_DESC = {
    'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
    'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
    'XLM Roberta XNLI (cross-lingual)': """XLM Roberta, a cross-lingual model, with a classification head trained on XNLI. Supported languages include: _English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu_.

Note that this model seems to be less reliable than the English-only models when classifying longer sequences.

Examples were automatically translated and may contain grammatical mistakes.

Sequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
}

# Intro blurb rendered at the top of the sidebar.
ZSL_DESC = """*Update 2024/05/02: This app demo'd the use of pre-trained NLI models for zero-shot topic classification right around the 2020 announcement of GPT-3. It was originally released as a standalone app on the Hugging Face servers before the introduction of HF spaces. It is kept here for posterity.*

Recently, the NLP science community has begun to pay increasing attention to zero-shot and few-shot applications, such as in the [paper from OpenAI](https://arxiv.org/abs/2005.14165) introducing GPT-3. This demo shows how 🤗 Transformers can be used for zero-shot topic classification, the task of predicting a topic that the model has not been trained on."""

# Code-snippet template shown when "Show code snippet" is checked. The single
# `{}` is filled with the selected model id via .format(); doubled `{{}}`
# braces survive .format() as literal `{}` in the rendered snippet.
CODE_DESC = """```python
from transformers import pipeline
classifier = pipeline('zero-shot-classification',
                      model='{}')
hypothesis_template = 'This text is about {{}}.' # the template used in this demo

classifier(sequence, labels,
           hypothesis_template=hypothesis_template,
           multi_class=multi_class)
# {{'sequence' ..., 'labels': ..., 'scores': ...}}
```"""

# Maps each sidebar display name to its Hugging Face Hub model id.
model_ids = {
    'Bart MNLI': 'facebook/bart-large-mnli',
    'Bart MNLI + Yahoo Answers': 'joeddav/bart-large-mnli-yahoo-answers',
    'XLM Roberta XNLI (cross-lingual)': 'joeddav/xlm-roberta-large-xnli'
}

# First GPU if available, otherwise CPU (-1), per the pipeline() device convention.
device = 0 if torch.cuda.is_available() else -1

@st.cache_resource
def load_models():
    """Load every NLI model once per server process.

    Returns a dict mapping Hugging Face model id -> loaded sequence
    classification model. Wrapped in st.cache_resource so the heavyweight
    models are downloaded/instantiated only once and shared across sessions.
    """
    # NOTE: renamed the comprehension variable from `id`, which shadowed the
    # builtin `id()`.
    return {
        model_id: AutoModelForSequenceClassification.from_pretrained(model_id)
        for model_id in model_ids.values()
    }

models = load_models()


@st.cache_resource(show_spinner=False)
def load_tokenizer(tok_id):
    """Load and cache the tokenizer for the given Hugging Face model id."""
    return AutoTokenizer.from_pretrained(tok_id)

@st.cache_data(show_spinner=False, hash_funcs={
    torch.nn.Parameter: lambda _: None
})
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class):
    """Classify *sequence* against *labels* with the selected NLI model.

    Returns the (labels, scores) lists from the zero-shot pipeline, sorted by
    the pipeline from most to least likely. Results are memoized per input by
    st.cache_data; torch parameters are excluded from the cache key.
    """
    zero_shot = pipeline(
        'zero-shot-classification',
        model=models[nli_model_id],
        tokenizer=load_tokenizer(nli_model_id),
        device=device,
    )
    result = zero_shot(
        sequence,
        candidate_labels=labels,
        hypothesis_template=hypothesis_template,
        multi_label=multi_class,
    )
    return result['labels'], result['scores']

def load_examples(model_id):
    """Load the bundled example texts for *model_id*.

    Reads ``texts-<model name>.json`` from the working directory and returns
    ``(names, mapping)`` where ``mapping[name] == (text, labels)``. A trailing
    'Custom' entry with empty text/labels lets the user type their own input.
    """
    short_name = model_id.split('/')[-1]
    examples = pd.read_json(f'texts-{short_name}.json')
    names = examples['name'].tolist()
    mapping = dict(zip(examples['name'], zip(examples['text'], examples['labels'])))
    names.append('Custom')
    mapping['Custom'] = ('', '')
    return names, mapping

def plot_result(top_topics, scores):
    """Render a horizontal bar chart of label confidences (as percentages)."""
    topic_arr = np.array(top_topics)
    pct = np.array(scores) * 100
    fig = px.bar(
        x=pct,
        y=topic_arr,
        orientation='h',
        labels={'x': 'Confidence', 'y': 'Label'},
        text=pct,
        range_x=(0, 115),  # headroom so outside-positioned labels are not clipped
        title='Top Predictions',
        color=np.linspace(0, 1, len(pct)),
        color_continuous_scale='GnBu',
    )
    fig.update(layout_coloraxis_showscale=False)
    fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(fig)


def main():
    """Streamlit entry point: render the sidebar, input widgets, and results chart."""
    # Inject the bundled stylesheet into the page.
    with open("style.css") as f:
        st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

    logo = Image.open('huggingface_logo.png')
    st.sidebar.image(logo, width=120)
    st.sidebar.markdown(ZSL_DESC)
    model_desc = st.sidebar.selectbox('Model', list(MODEL_DESC.keys()), 0)
    do_print_code = st.sidebar.checkbox('Show code snippet', False)
    st.sidebar.markdown('#### Model Description')
    st.sidebar.markdown(MODEL_DESC[model_desc])
    st.sidebar.markdown('Originally proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161). Read more in our [blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html).')

    model_id = model_ids[model_desc]
    ex_names, ex_map = load_examples(model_id)

    st.title('Zero Shot Topic Classification')
    example = st.selectbox('Choose an example', ex_names)
    # Scale the text area roughly to the example's word count, capped at 200px.
    height = min((len(ex_map[example][0].split()) + 1) * 2, 200)
    sequence = st.text_area('Text', ex_map[example][0], key='sequence', height=height)
    labels = st.text_input('Possible topics (separated by `,`)', ex_map[example][1], max_chars=1000)
    multi_class = st.checkbox('Allow multiple correct topics', value=True)

    hypothesis_template = "This text is about {}."

    # Split the comma-separated labels, trim whitespace, drop empties, dedupe.
    labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
    if len(labels) == 0 or len(sequence) == 0:
        st.write('Enter some text and at least one possible topic to see predictions.')
        return

    if do_print_code:
        st.markdown(CODE_DESC.format(model_id))

    with st.spinner('Classifying...'):
        top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class)

    # Keep at most MAX_GRAPH_ROWS predictions; reversed so the most confident
    # label renders as the top bar of the horizontal chart.
    plot_result(top_topics[::-1][-MAX_GRAPH_ROWS:], scores[::-1][-MAX_GRAPH_ROWS:])



# Script entry point.
if __name__ == '__main__':
    main()