import streamlit as st
from transformers import pipeline

# Hugging Face checkpoint used for zero-shot classification.
MODEL_NAME = "facebook/bart-large-mnli"


@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``pipeline(...)`` call would re-instantiate the model on every
    rerun. ``show_spinner=False`` keeps this call from emitting any UI element
    before ``st.set_page_config`` runs in ``main``.
    """
    return pipeline("zero-shot-classification", model=MODEL_NAME)


# Load the zero-shot classification model (cached across script reruns).
classifier = _load_classifier()


def _parse_labels(raw):
    """Split a comma-separated string into candidate labels.

    Whitespace around each label is stripped, empty entries (e.g. from a
    trailing or doubled comma) are dropped, and duplicates are removed while
    keeping first-seen order.

    Args:
        raw: The comma-separated label string typed by the user.

    Returns:
        A list of unique, non-empty label strings (possibly empty).
    """
    stripped = (part.strip() for part in raw.split(","))
    return list(dict.fromkeys(label for label in stripped if label))


def main():
    """Render the zero-shot text classification Streamlit app."""
    # Must be the first Streamlit command executed on the page.
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    # App title and description
    st.title("Zero-Shot Text Classification")
    st.write(
        """
    This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
    Enter a sentence and candidate labels, and the model will predict the most relevant label.
    """
    )

    # Input text box for the sentence to classify
    sequence_to_classify = st.text_input("Enter the sentence to classify:", key="input_sentence")

    # Candidate labels input with help text
    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    raw_labels = st.text_input("Candidate Labels:", key="input_labels")

    # Confidence threshold slider
    confidence_threshold = st.slider(
        "Confidence Threshold:",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.01,
        key="confidence_threshold",
    )

    if not st.button("Classify", key="classify_button"):
        return

    candidate_labels = _parse_labels(raw_labels)
    if not sequence_to_classify.strip() or not candidate_labels:
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    with st.spinner("Classifying..."):
        classification_result = classifier(sequence_to_classify, candidate_labels)

    labels = classification_result["labels"]
    scores = classification_result["scores"]
    # Locate the best label explicitly rather than relying on the output order.
    best_index = max(range(len(scores)), key=scores.__getitem__)
    best_label = labels[best_index]
    best_score = scores[best_index]

    # Display classification results
    st.subheader("Classification Results:")
    st.markdown("---")
    # No unsafe_allow_html: the label is user-supplied text.
    st.markdown(f"**{best_label}**: {best_score:.2f}")
    if best_score < confidence_threshold:
        st.caption("The top label's score is below the confidence threshold.")
    st.markdown("---")

    for index, (label, score) in enumerate(zip(labels, scores)):
        if index == best_index:
            continue  # Already shown above; skip by position, not by label text.
        if score >= confidence_threshold:
            st.text(f"'{label}': {score:.2f}")
        else:
            st.text(f"'{label}': Below threshold ({score:.2f})")


if __name__ == "__main__":
    main()