# ajeetkumar01's picture
# updated the app view
# 9c9e66f verified
import streamlit as st
from transformers import pipeline
# Load the zero-shot classification model.
#
# Streamlit re-executes this whole script on every widget interaction, so a
# bare module-level ``pipeline(...)`` call would reload the model weights on
# every rerun. ``st.cache_resource`` keeps a single loaded instance instead.
# ``show_spinner=False`` keeps the cached call from rendering any UI element
# before ``st.set_page_config`` runs inside ``main``.
# NOTE(review): the cached pipeline is shared across user sessions; presumably
# safe for concurrent inference — confirm if the app sees parallel traffic.
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the BART-Large-MNLI zero-shot pipeline (created once, then cached)."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


classifier = _load_classifier()
def _parse_labels(raw_labels):
    """Split a comma-separated string into unique, non-empty candidate labels.

    Each entry is whitespace-stripped. Blank entries (from ``"a,,b"``, a
    trailing comma, or only spaces) are dropped. Duplicates are removed while
    first-seen order is preserved.

    Args:
        raw_labels: The text typed into the "Candidate Labels" box.

    Returns:
        A list of label strings; empty if no usable label was entered.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


# Define Streamlit app
def main():
    """Render the zero-shot text classification page.

    Collects a sentence and a comma-separated list of candidate labels. When
    the "Classify" button is pressed, it runs the module-level ``classifier``
    on them. It then shows the highest-scoring label, followed by the other
    labels annotated against the confidence-threshold slider.
    """
    # Set page title and favicon
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    # App title and description
    st.title("Zero-Shot Text Classification")
    st.write(
        """
This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
Enter a sentence and candidate labels, and the model will predict the most relevant label.
"""
    )

    # Input text box for the sentence to classify
    sequence_to_classify = st.text_input("Enter the sentence to classify:", key="input_sentence",
                                         type="default", value="")

    # Candidate labels input with help text
    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    raw_labels = st.text_input("Candidate Labels:", key="input_labels", type="default", value="")

    # Confidence threshold slider
    confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01,
                                     key="confidence_threshold")

    # Classification button
    if not st.button("Classify", key="classify_button"):
        return

    # Whitespace-only input and blank/duplicate labels are not meaningful
    # hypotheses for the model, so normalise before deciding whether to run.
    sequence = sequence_to_classify.strip()
    candidate_labels = _parse_labels(raw_labels)
    if not sequence or not candidate_labels:
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    # Perform classification
    classification_result = classifier(sequence, candidate_labels)
    labels = classification_result["labels"]
    scores = classification_result["scores"]

    # Find label with highest score (first one wins on ties)
    best_index = scores.index(max(scores))

    # Display classification results. The top label is always shown, even
    # when its score is below the threshold.
    st.subheader("Classification Results:")
    st.markdown("---")
    # Labels are user input: render without unsafe_allow_html so any HTML
    # in them is escaped rather than injected into the page.
    st.markdown(f"**{labels[best_index]}**: {scores[best_index]:.2f}")
    st.markdown("---")
    # Skip the top result by position rather than by label text, so that a
    # label equal to the winner is never hidden by mistake.
    for index, (label, score) in enumerate(zip(labels, scores)):
        if index == best_index:
            continue
        if score >= confidence_threshold:
            st.text(f"'{label}': {score:.2f}")
        else:
            st.text(f"'{label}': Below threshold ({score:.2f})")


# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()