Spaces:
Runtime error
Runtime error
File size: 6,214 Bytes
a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 e690399 a922691 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import numpy as np
import contextlib
import plotly.express as px
import pandas as pd
from PIL import Image
import datetime
import os
import psutil
# Append the current timestamp as one line of hit_log.txt every time this
# script is executed.
with open("hit_log.txt", mode='a') as log_file:
    log_file.write(f"{datetime.datetime.now()}\n")
# Markdown description of each selectable model, keyed by the display name
# offered in the sidebar 'Model' select box. The keys are the same display
# names used by `model_ids`. The last entry is a multi-line literal: its line
# breaks are part of the rendered text.
MODEL_DESC = {
    'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
    'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
    'XLM Roberta XNLI (cross-lingual)': """XLM Roberta, a cross-lingual model, with a classification head trained on XNLI. Supported languages include: _English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu_.
Note that this model seems to be less reliable than the English-only models when classifying longer sequences.
Examples were automatically translated and may contain grammatical mistakes.
Sequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
}
# Introductory Markdown text rendered at the top of the sidebar by main().
ZSL_DESC = """Recently, the NLP science community has begun to pay increasing attention to zero-shot and few-shot applications, such as in the [paper from OpenAI](https://arxiv.org/abs/2005.14165) introducing GPT-3. This demo shows how 🤗 Transformers can be used for zero-shot topic classification, the task of predicting a topic that the model has not been trained on."""
# Markdown code snippet shown when 'Show code snippet' is ticked. main()
# renders it with `CODE_DESC.format(model_id)`: the single `{}` receives the
# model id, while the doubled braces (`{{}}`, `{{...}}`) come out as literal
# braces in the displayed snippet.
CODE_DESC = """```python
from transformers import pipeline
classifier = pipeline('zero-shot-classification',
model='{}')
hypothesis_template = 'This text is about {{}}.' # the template used in this demo
classifier(sequence, labels,
hypothesis_template=hypothesis_template,
multi_class=multi_class)
# {{'sequence' ..., 'labels': ..., 'scores': ...}}
```"""
# Maps each display name (same keys as MODEL_DESC) to the model identifier
# that is passed to `from_pretrained` when loading the model and its tokenizer.
model_ids = {
    'Bart MNLI': 'facebook/bart-large-mnli',
    'Bart MNLI + Yahoo Answers': 'joeddav/bart-large-mnli-yahoo-answers',
    'XLM Roberta XNLI (cross-lingual)': 'joeddav/xlm-roberta-large-xnli'
}
# Value handed to the pipeline's `device` argument: 0 when CUDA is available,
# otherwise -1 (presumably "run on CPU", per the transformers pipeline
# convention -- confirm against the installed version).
device = 0 if torch.cuda.is_available() else -1
@st.cache(allow_output_mutation=True)
def load_models():
    """Load a sequence-classification model for every entry of `model_ids`.

    Returns:
        dict: maps each model identifier (the values of `model_ids`) to the
        object returned by
        `AutoModelForSequenceClassification.from_pretrained` for it.
    """
    loaded = {}
    for checkpoint in model_ids.values():
        loaded[checkpoint] = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return loaded
# Module-level dict of loaded models; get_most_likely() indexes it by model id.
models = load_models()
@st.cache(allow_output_mutation=True, show_spinner=False)
def load_tokenizer(tok_id):
    """Return the tokenizer `AutoTokenizer.from_pretrained` builds for `tok_id`.

    Args:
        tok_id: Identifier forwarded unchanged to `from_pretrained`.
    """
    tokenizer = AutoTokenizer.from_pretrained(tok_id)
    return tokenizer
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class, do_print_code):
    """Run zero-shot classification of `sequence` against `labels`.

    Args:
        nli_model_id: Key into the module-level `models` dict; also passed to
            `load_tokenizer` to obtain the matching tokenizer.
        sequence: Text to classify.
        labels: Candidate topic labels.
        hypothesis_template: Template forwarded to the pipeline as
            `hypothesis_template`.
        multi_class: Forwarded to the pipeline as `multi_class`.
        do_print_code: Not used in the body; kept so existing callers keep
            working.

    Returns:
        tuple: `(labels, scores)` taken from the `'labels'` and `'scores'`
        entries of the pipeline output.
    """
    classifier = pipeline(
        'zero-shot-classification',
        model=models[nli_model_id],
        tokenizer=load_tokenizer(nli_model_id),
        device=device,
    )
    # Pass the options by keyword, exactly as the snippet in CODE_DESC does.
    # Passing them positionally depends on the pipeline's parameter order and
    # is rejected by pipeline versions that accept only one extra positional
    # argument (the candidate labels).
    outputs = classifier(
        sequence,
        labels,
        hypothesis_template=hypothesis_template,
        multi_class=multi_class,
    )
    return outputs['labels'], outputs['scores']
def load_examples(model_id):
    """Load the example texts that ship with a model.

    Reads `texts-<name>.json` from the current working directory, where
    `<name>` is the part of `model_id` after its last '/'. The file must
    provide `name`, `text` and `labels` columns.

    Args:
        model_id: Model identifier, e.g. 'facebook/bart-large-mnli'.

    Returns:
        tuple: `(names, mapping)` where `names` lists the example names in
        file order followed by 'Custom', and `mapping` maps each name to a
        `(text, labels)` pair; 'Custom' maps to `('', '')`.
    """
    model_id_stripped = model_id.split('/')[-1]
    df = pd.read_json(f'texts-{model_id_stripped}.json')
    names = df['name'].tolist()
    # Walk the three columns in lockstep instead of indexing row by row.
    mapping = {
        name: (text, example_labels)
        for name, text, example_labels in zip(names, df['text'], df['labels'])
    }
    names.append('Custom')
    mapping['Custom'] = ('', '')
    return names, mapping
def plot_result(top_topics, scores):
    """Draw a horizontal bar chart of label scores on the Streamlit page.

    Args:
        top_topics: Labels, one per bar.
        scores: Scores matching `top_topics`; they are multiplied by 100
            before plotting and shown as percentages with one decimal.
    """
    topic_labels = np.array(top_topics)
    percentages = np.array(scores)
    percentages *= 100
    # One evenly spaced colour value per bar, spread over the 'GnBu' scale.
    bar_colors = np.linspace(0, 1, len(percentages))
    fig = px.bar(
        x=percentages,
        y=topic_labels,
        orientation='h',
        labels={'x': 'Confidence', 'y': 'Label'},
        text=percentages,
        range_x=(0, 115),
        title='Top Predictions',
        color=bar_colors,
        color_continuous_scale='GnBu',
    )
    fig.update(layout_coloraxis_showscale=False)
    fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(fig)
def _parse_labels(raw_labels):
    """Split a comma-separated string into unique, stripped, non-empty labels.

    Labels keep the order in which they first appear, so the same input
    always produces the same list.
    """
    stripped = (piece.strip() for piece in raw_labels.split(','))
    return list(dict.fromkeys(label for label in stripped if label))


def _ensure_port_forward():
    """Start the socat forwarder unless a process named 'socat' already runs."""
    if any(proc.name() == "socat" for proc in psutil.process_iter()):
        return
    os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')


def main():
    """Render the zero-shot topic classification page.

    Builds the sidebar (logo, introduction, model selection, model
    description), then the main area (example picker, text box, topic input).
    When both a text and at least one topic are present, the selected model
    classifies the text and up to ten results are plotted.
    """
    with open("style.css") as f:
        st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

    # Sidebar.
    logo = Image.open('huggingface_logo.png')
    st.sidebar.image(logo, width=120)
    st.sidebar.markdown(ZSL_DESC)
    model_desc = st.sidebar.selectbox('Model', list(MODEL_DESC.keys()), 0)
    do_print_code = st.sidebar.checkbox('Show code snippet', False)
    st.sidebar.markdown('#### Model Description')
    st.sidebar.markdown(MODEL_DESC[model_desc])
    st.sidebar.markdown('Originally proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161). Read more in our [blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html).')

    model_id = model_ids[model_desc]
    ex_names, ex_map = load_examples(model_id)

    # Main area inputs.
    st.title('Zero Shot Topic Classification')
    example = st.selectbox('Choose an example', ex_names)
    example_text, example_labels = ex_map[example]
    # Text box height grows with the word count of the example, capped at 200.
    height = min((len(example_text.split()) + 1) * 2, 200)
    sequence = st.text_area('Text', example_text, key='sequence', height=height)
    raw_labels = st.text_input('Possible topics (separated by `,`)', example_labels, max_chars=1000)
    multi_class = st.checkbox('Allow multiple correct topics', value=True)
    hypothesis_template = "This text is about {}."

    labels = _parse_labels(raw_labels)
    # A text consisting only of whitespace counts as empty.
    if not labels or not sequence.strip():
        st.write('Enter some text and at least one possible topic to see predictions.')
        return

    if do_print_code:
        st.markdown(CODE_DESC.format(model_id))

    with st.spinner('Classifying...'):
        top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class, do_print_code)

    # Reverse the returned order and keep the last ten entries of the
    # reversed lists for the chart.
    plot_result(top_topics[::-1][-10:], scores[::-1][-10:])

    # NOTE(review): the reviewed copy of this file had lost its indentation;
    # this step is assumed to be the last statement of main() -- confirm
    # against the repository.
    _ensure_port_forward()
# Build the page only when this file is executed as the main module.
if __name__ == '__main__':
    main()
|