ashrestha's picture
Update app.py
db840ed
raw
history blame
654 Bytes
import streamlit as st
from transformers import pipeline
def _load_classifier():
    """Build the zero-shot classification pipeline (loads model weights)."""
    return pipeline(
        "zero-shot-classification",
        model="valhalla/distilbart-mnli-12-1",
    )


# Streamlit re-executes this whole script on every widget interaction, so
# without caching the model would be reloaded on each rerun. Cache the
# pipeline so it is built once per process. `st.cache_resource` exists in
# newer Streamlit releases; fall back to the legacy `st.cache` otherwise.
if hasattr(st, "cache_resource"):
    classifier = st.cache_resource(_load_classifier)()
else:
    classifier = st.cache(allow_output_mutation=True)(_load_classifier)()
# Input form: the text to classify plus a comma-separated list of candidate
# labels. Widgets grouped in a Streamlit form submit their values together
# when the Submit button is pressed.
with st.form(key='inputs'):
    input_text = st.text_area(label="Input text")
    input_label = st.text_input(
        label="Labels",
        placeholder="support, help, important",
    )
    submit_button = st.form_submit_button('Submit')
if submit_button:
    # Split the comma-separated label list. Blank entries (e.g. from "a, , b,"
    # or an empty field) are dropped so no empty candidate label reaches the
    # model; duplicates are removed while preserving first-seen order.
    stripped = (label.strip() for label in input_label.split(","))
    labels = list(dict.fromkeys(label for label in stripped if label))

    if not input_text.strip() or not labels:
        st.warning("Please enter some text and at least one label.")
    else:
        # `multi_label=True` scores every label independently; `multi_class`
        # is the deprecated name for the same option.
        pred = classifier(input_text, labels, multi_label=True)
        # The pipeline returns labels ordered by likelihood (highest first),
        # so the first two entries are the top predictions.
        out = f"Top predicted labels are {', '.join(pred['labels'][:2])}"
        st.success(out)