"""Streamlit app for zero-shot text classification with facebook/bart-large-mnli."""
import streamlit as st
from transformers import pipeline

MODEL_NAME = "facebook/bart-large-mnli"


@st.cache_resource(show_spinner="Loading classification model...")
def load_classifier():
    """Return the zero-shot classification pipeline, built once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    creating the pipeline at module level reloaded the model weights on each
    rerun. ``st.cache_resource`` keeps one shared instance instead.
    """
    return pipeline("zero-shot-classification", model=MODEL_NAME)


def _parse_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty, unique labels.

    Order of first appearance is preserved. Blank entries (from ``"a,,b"`` or a
    trailing comma) are dropped so they are not scored as meaningless labels,
    and duplicates are removed so they do not split probability mass between
    identical hypotheses.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    return list(dict.fromkeys(label for label in stripped if label))


def main():
    """Render the app: inputs in the left column, the top-scoring label on the right."""
    # set_page_config must be the first Streamlit command executed on each run,
    # so the (cached) model load happens after it.
    st.set_page_config(
        page_title="Zero-Shot Text Classification",
        page_icon=":rocket:",
        layout="wide",  # Wide layout gives the two columns more room.
        initial_sidebar_state="expanded",
    )

    st.title("Zero-Shot Text Classification")
    st.markdown(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    classifier = load_classifier()

    # Narrow input column on the left, wider results column on the right.
    col1, col2 = st.columns([1, 2])

    with col1:
        sequence_to_classify = st.text_input("Enter the sentence to classify:")

        st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
        raw_labels = st.text_input("Candidate Labels:")

        confidence_threshold = st.slider(
            "Confidence Threshold:",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.01,
            key="confidence_threshold",
            help="Move the slider to adjust the confidence threshold.",
        )

        classify_button = st.button(
            "Classify",
            key="classify_button",
            help="Click the button to classify the input text with the provided labels.",
        )

    with col2:
        if not classify_button:
            return

        candidate_labels = _parse_labels(raw_labels)
        if not sequence_to_classify.strip() or not candidate_labels:
            st.warning("Please enter a sentence and at least one candidate label.")
            return

        with st.spinner("Classifying..."):
            classification_result = classifier(sequence_to_classify, candidate_labels)

        # Pick the highest-scoring label; on ties max() keeps the first one.
        max_label, max_score = max(
            zip(classification_result["labels"], classification_result["scores"]),
            key=lambda pair: pair[1],
        )

        st.subheader("Classification Result:")
        # Labels are user-supplied text, so they are never rendered as raw HTML.
        if max_score >= confidence_threshold:
            st.write(f"- **{max_label}**: {max_score:.2f}")
        else:
            st.write(f"- {max_label}: Below threshold ({max_score:.2f})")


if __name__ == "__main__":
    main()