Spaces:
Runtime error
Runtime error
harry-stark
commited on
Commit
•
17a8518
1
Parent(s):
a93f647
Added app files
Browse files- app.py +22 -0
- hf_model.py +16 -0
- requirements.txt +4 -0
- utils.py +17 -0
app.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence
|
2 |
+
import streamlit as st
|
3 |
+
from hf_model import classifier_zero,load_model
|
4 |
+
from utils import plot_result
|
5 |
+
classifier=load_model()
|
6 |
+
if __name__ == '__main__':
|
7 |
+
st.header("Zero Shot Classification")
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
sequence = st.text_area(label="Input Sequence")
|
12 |
+
labels = st.text_input('Possible topics (separated by `,`)', max_chars=1000)
|
13 |
+
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
14 |
+
if len(labels) == 0 or len(sequence) == 0:
|
15 |
+
st.write('Enter some text and at least one possible topic to see predictions.')
|
16 |
+
|
17 |
+
multi_class = st.checkbox('Allow multiple correct topics', value=True)
|
18 |
+
|
19 |
+
with st.spinner('Classifying...'):
|
20 |
+
top_topics, scores = classifier_zero(classifier,sequence=sequence,labels=labels,multi_class=multi_class)
|
21 |
+
plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
|
22 |
+
|
hf_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipeline
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def load_model():
|
5 |
+
|
6 |
+
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
|
7 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
9 |
+
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
10 |
+
return classifier
|
11 |
+
|
12 |
+
|
13 |
+
def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
|
14 |
+
outputs=classifier(sequence, labels,multi_label=multi_class)
|
15 |
+
return outputs['labels'], outputs['scores']
|
16 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers[sentencepiece]==4.11.0
|
2 |
+
streamlit
|
3 |
+
plotly
|
4 |
+
torch
|
utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
import plotly.express as px
|
4 |
+
def plot_result(top_topics, scores):
|
5 |
+
top_topics = np.array(top_topics)
|
6 |
+
scores = np.array(scores)
|
7 |
+
scores *= 100
|
8 |
+
fig = px.bar(x=scores, y=top_topics, orientation='h',
|
9 |
+
labels={'x': 'Confidence', 'y': 'Label'},
|
10 |
+
text=scores,
|
11 |
+
range_x=(0,115),
|
12 |
+
title='Top Predictions',
|
13 |
+
color=np.linspace(0,1,len(scores)),
|
14 |
+
color_continuous_scale='GnBu')
|
15 |
+
fig.update(layout_coloraxis_showscale=False)
|
16 |
+
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
17 |
+
st.plotly_chart(fig)
|