harry-stark commited on
Commit
17a8518
1 Parent(s): a93f647

Added app files

Browse files
Files changed (4) hide show
  1. app.py +22 -0
  2. hf_model.py +16 -0
  3. requirements.txt +4 -0
  4. utils.py +17 -0
app.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ import streamlit as st
3
+ from hf_model import classifier_zero,load_model
4
+ from utils import plot_result
5
+ classifier=load_model()
6
+ if __name__ == '__main__':
7
+ st.header("Zero Shot Classification")
8
+
9
+
10
+
11
+ sequence = st.text_area(label="Input Sequence")
12
+ labels = st.text_input('Possible topics (separated by `,`)', max_chars=1000)
13
+ labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
14
+ if len(labels) == 0 or len(sequence) == 0:
15
+ st.write('Enter some text and at least one possible topic to see predictions.')
16
+
17
+ multi_class = st.checkbox('Allow multiple correct topics', value=True)
18
+
19
+ with st.spinner('Classifying...'):
20
+ top_topics, scores = classifier_zero(classifier,sequence=sequence,labels=labels,multi_class=multi_class)
21
+ plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
22
+
hf_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipeline
2
+ import torch
3
+
4
+ def load_model():
5
+
6
+ model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
10
+ return classifier
11
+
12
+
13
+ def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
14
+ outputs=classifier(sequence, labels,multi_label=multi_class)
15
+ return outputs['labels'], outputs['scores']
16
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers[sentencepiece]==4.11.0
2
+ streamlit
3
+ plotly
4
+ torch
utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import plotly.express as px
4
+ def plot_result(top_topics, scores):
5
+ top_topics = np.array(top_topics)
6
+ scores = np.array(scores)
7
+ scores *= 100
8
+ fig = px.bar(x=scores, y=top_topics, orientation='h',
9
+ labels={'x': 'Confidence', 'y': 'Label'},
10
+ text=scores,
11
+ range_x=(0,115),
12
+ title='Top Predictions',
13
+ color=np.linspace(0,1,len(scores)),
14
+ color_continuous_scale='GnBu')
15
+ fig.update(layout_coloraxis_showscale=False)
16
+ fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
17
+ st.plotly_chart(fig)