"""Streamlit demo for zero-shot text classification.

The user enters a piece of text and a comma-separated list of candidate
labels. A Hugging Face zero-shot-classification pipeline scores each label.
In multi-label mode, the scores are drawn as a bar chart. Otherwise, the top
two labels are reported.
"""
import json  # NOTE(review): not used by this script; kept in case it is needed elsewhere

import streamlit as st
from transformers import pipeline

MODEL_NAME = "valhalla/distilbart-mnli-12-1"


def _load_classifier():
    """Build the zero-shot-classification pipeline for MODEL_NAME."""
    return pipeline("zero-shot-classification", model=MODEL_NAME)


# Streamlit re-executes this whole script on every widget interaction. Cache
# the pipeline so the model is built once rather than on every rerun.
# `st.cache_resource` is the current API. Older Streamlit releases only
# provide `st.cache`, which needs `allow_output_mutation=True` so it does not
# try to hash the returned model object.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)
classifier = _cache_resource(_load_classifier)()


def _parse_labels(raw):
    """Split a comma-separated string into stripped, non-empty labels.

    Blank entries (an empty field, a trailing comma, ", ,") are dropped so
    they are never sent to the model as candidate labels.
    """
    return [label.strip() for label in raw.split(",") if label.strip()]


with st.form("inputs"):
    input_text = st.text_area("Input text")
    input_label = st.text_input("Labels", placeholder="support, help, important")
    input_multi = st.checkbox("Allow multiple true classes", value=False)
    submit_button = st.form_submit_button(label="Submit")

if submit_button:
    labels = _parse_labels(input_label)
    if not input_text.strip():
        st.warning("Please enter some input text.")
    elif not labels:
        st.warning("Please enter at least one label.")
    else:
        # `multi_label` is the current name of the deprecated `multi_class`
        # argument of the zero-shot-classification pipeline.
        pred = classifier(input_text, labels, multi_label=input_multi)
        if input_multi:
            # Chart only the two columns the Vega-Lite spec below references.
            chart_data = {"labels": pred["labels"], "scores": pred["scores"]}
            st.vega_lite_chart(
                chart_data,
                {
                    "mark": {"type": "bar", "tooltip": False},
                    "encoding": {
                        "x": {"field": "scores", "type": "quantitative"},
                        "y": {"field": "labels", "type": "nominal"},
                    },
                },
                use_container_width=True,
            )
        else:
            # Report (up to) the two highest-ranked labels; slicing is safe
            # even when only one label was supplied.
            out = f"Top predicted labels are {', '.join(pred['labels'][:2])}"
            st.success(out)