"""Streamlit app for zero-shot text classification with BART-Large-MNLI."""

import streamlit as st
from transformers import pipeline


@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the zero-shot classification pipeline.

    Streamlit re-executes this script on every widget interaction, so the
    pipeline is cached as a shared resource instead of being rebuilt on each
    rerun. The spinner is disabled so that no UI element is emitted before
    ``st.set_page_config`` runs in ``main``.

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Module-level handle to the (cached) zero-shot classification model.
classifier = _load_classifier()


def _parse_labels(raw_labels):
    """Split a comma-separated string into a list of candidate labels.

    Surrounding whitespace is stripped from each label, and empty entries
    (e.g. from a trailing comma or ``"a,,b"``) are dropped so that blank
    labels are never passed to the model.

    Args:
        raw_labels: The raw text typed by the user.

    Returns:
        A list of non-empty, stripped label strings (possibly empty).
    """
    stripped = (part.strip() for part in raw_labels.split(","))
    return [label for label in stripped if label]


def main():
    """Render the app: collect inputs, classify on demand, show the scores."""
    # Set page title and favicon.
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    # App title and description.
    st.title("Zero-Shot Text Classification")
    st.markdown(
        "This app performs zero-shot text classification using the Facebook "
        "BART-Large-MNLI model. Enter a sentence and candidate labels, and the "
        "model will predict the most relevant label."
    )

    # Input text box for the sentence to classify.
    sequence_to_classify = st.text_input("Enter the sentence to classify:")

    # Candidate labels input with help text.
    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    candidate_labels = st.text_input("Candidate Labels:")

    # Scores at or above this value are reported as-is; lower ones are flagged.
    confidence_threshold = st.slider(
        "Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01
    )

    # Nothing to do until the user explicitly asks for a classification.
    if not st.button("Classify"):
        return

    labels = _parse_labels(candidate_labels)
    if not sequence_to_classify.strip() or not labels:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    classification_result = classifier(sequence_to_classify, labels)

    st.subheader("Classification Results:")
    scored_labels = zip(classification_result["labels"], classification_result["scores"])
    for label, score in scored_labels:
        if score >= confidence_threshold:
            st.write(f"- {label}: {score}")
        else:
            st.write(f"- {label}: Below threshold ({score})")


if __name__ == "__main__":
    main()