ajeetkumar01's picture
updated script for top result classifications
6fcc580 verified
raw
history blame
No virus
3.18 kB
import html

import streamlit as st
from transformers import pipeline
# Load the zero-shot classification model once and reuse it across reruns.
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the facebook/bart-large-mnli zero-shot classification pipeline.

    Streamlit re-executes this entire script on every widget interaction, so
    constructing the pipeline at module level would reload the model weights
    on every rerun. ``st.cache_resource`` keeps one shared instance instead.
    ``show_spinner=False`` stops the cache from rendering a spinner element
    before ``st.set_page_config`` is called in ``main()``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


classifier = _load_classifier()
# Define Streamlit app
def _parse_candidate_labels(raw_labels):
    """Split a comma-separated label string into a list of non-empty labels.

    Whitespace around each label is stripped, and empty entries (from input
    such as ``"a,,b"``, ``" , "`` or a trailing comma) are dropped so the
    model is never asked to score a blank label.
    """
    return [label.strip() for label in raw_labels.split(",") if label.strip()]


def _top_prediction(classification_result):
    """Return the ``(label, score)`` pair with the highest score.

    ``classification_result`` is the dict returned by the zero-shot pipeline,
    read here as parallel ``"labels"`` and ``"scores"`` lists. On ties the
    first maximal entry wins.
    """
    return max(
        zip(classification_result["labels"], classification_result["scores"]),
        key=lambda pair: pair[1],
    )


def main():
    """Render the zero-shot text classification page.

    The left column collects the sentence, comma-separated candidate labels,
    a confidence threshold and a Classify button. When the button is pressed,
    the right column shows the top-scoring label, greyed out if its score is
    below the threshold.
    """
    # Page config must be set before any other Streamlit element is rendered.
    st.set_page_config(
        page_title="Zero-Shot Text Classification",
        page_icon=":rocket:",
        layout="wide",  # Wide layout gives the results column more room
        initial_sidebar_state="expanded",  # Expand sidebar by default
    )

    # App title and description
    st.title(":paintbrush: Zero-Shot Text Classification")
    st.markdown(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    # Two-column layout: inputs on the left (1/3), results on the right (2/3)
    col1, col2 = st.columns([1, 2])

    # Left pane: input widgets
    with col1:
        sequence_to_classify = st.text_input("Enter the sentence to classify:")
        st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
        candidate_labels_raw = st.text_input("Candidate Labels:")
        confidence_threshold = st.slider(
            "Confidence Threshold:",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.01,
            key="confidence_threshold",
            help="Move the slider to adjust the confidence threshold.",
        )
        classify_button = st.button(
            "Classify",
            key="classify_button",
            help="Click the button to classify the input text with the provided labels.",
        )

    # Right pane: results (only after Classify is pressed)
    with col2:
        if not classify_button:
            return

        sequence = sequence_to_classify.strip()
        labels = _parse_candidate_labels(candidate_labels_raw)
        if not sequence or not labels:
            # Give feedback instead of silently doing nothing on missing input.
            st.warning("Please enter a sentence and at least one candidate label.")
            return

        classification_result = classifier(sequence, labels)
        max_label, max_score = _top_prediction(classification_result)

        st.subheader("Classification Result:")
        if max_score >= confidence_threshold:
            # No raw HTML needed here, so user text is not rendered as HTML.
            st.markdown(f"- **{max_label}**: {max_score:.2f}")
        else:
            # The label is user input: escape it before embedding in raw HTML.
            st.markdown(
                f"- <span style='color: #888;'>{html.escape(max_label)}:</span> "
                f"Below threshold ({max_score:.2f})",
                unsafe_allow_html=True,
            )
if __name__ == "__main__":
    # `streamlit run` executes this file as __main__, so the page is built
    # by main() on each script run.
    main()