File size: 1,040 Bytes
1a80035
06ba6f1
1a80035
8f8b328
 
 
 
 
 
81b6ae4
4791a7a
 
06ba6f1
81b6ae4
1a80035
79b6ab8
8f8b328
79b6ab8
06ba6f1
 
 
2d6cd9e
 
 
 
2945a40
7641eb4
06ba6f1
087b7c5
06ba6f1
 
db840ed
4791a7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import streamlit as st
import json

from transformers import pipeline


# Hugging Face model id used by the zero-shot classification pipeline.
MODEL_NAME = "valhalla/distilbart-mnli-12-1"

# Streamlit re-executes this whole script on every interaction, so without
# caching the pipeline (and its weights) would be rebuilt on every submit.
# st.cache_resource (Streamlit >= 1.18) supersedes st.cache; fall back to the
# older decorator when cache_resource is not available.
_cache_model = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_model
def load_classifier():
    """Build the zero-shot classification pipeline, cached across reruns."""
    return pipeline("zero-shot-classification", model=MODEL_NAME)


classifier = load_classifier()

# Input form. Widgets inside st.form hold their values until the submit
# button is pressed, so classification only runs on an explicit submit.
with st.form(key="inputs"):
    input_text = st.text_area(label="Input text")
    input_label = st.text_input(
        label="Labels",
        placeholder="support, help, important",
    )
    input_multi = st.checkbox(label="Allow multiple true classes", value=False)
    submit_button = st.form_submit_button("Submit")

if submit_button:
    # Split the comma-separated label field, dropping blank entries such as
    # those produced by a trailing comma ("a, b,") or an empty field.
    labels = [label.strip() for label in input_label.split(",") if label.strip()]

    if not input_text.strip() or not labels:
        # Guard here instead of letting the pipeline raise and show a traceback.
        st.warning("Please enter some text and at least one label.")
    else:
        # `multi_label` is the current name of the deprecated `multi_class` kwarg.
        pred = classifier(input_text, labels, multi_label=input_multi)

        if input_multi:
            # Chart only the parallel label/score lists from the prediction.
            chart_data = {"labels": pred["labels"], "scores": pred["scores"]}
            st.vega_lite_chart(
                chart_data,
                {
                    "mark": {"type": "bar", "tooltip": False},
                    "encoding": {
                        "x": {"field": "scores", "type": "quantitative"},
                        "y": {"field": "labels", "type": "nominal"},
                    },
                },
                use_container_width=True,
                height=250,
            )
        else:
            # The pipeline returns labels ordered by score; index 0 is the best.
            st.success(f"Top predicted label is {pred['labels'][0]}")