ajeetkumar01's picture
Update app.py
6f08be9 verified
raw
history blame
1.92 kB
import streamlit as st
from transformers import pipeline
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the zero-shot classification pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction; caching
    the pipeline as a shared resource avoids re-instantiating the model on
    each rerun. The spinner is disabled so that no UI element is emitted
    before ``st.set_page_config`` is called in ``main``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Module-level handle used by main(); backed by the cached loader above.
classifier = _load_classifier()
# Define Streamlit app
def _parse_candidate_labels(raw_labels):
    """Split a comma-separated string into a list of clean, unique labels.

    Surrounding whitespace is stripped from each entry, empty entries (for
    example from a trailing or doubled comma) are dropped, and duplicates
    are removed while keeping first-seen order.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


def main():
    """Render the zero-shot classification page and handle a classify request.

    Draws the input widgets (sentence, comma-separated candidate labels and a
    confidence threshold). When the "Classify" button is pressed, the inputs
    are validated, passed to the module-level ``classifier`` and every
    returned label is listed with its score; labels scoring under the
    threshold are marked as such. If the sentence is blank or no usable label
    remains after parsing, a warning is shown instead of calling the model.
    """
    # Set page title and favicon
    st.set_page_config(page_title="Zero-Shot Text Classification", page_icon=":rocket:")

    # App title and description
    st.title("Zero-Shot Text Classification")
    st.markdown("""
This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
Enter a sentence and candidate labels, and the model will predict the most relevant label.
""")

    # Input text box for the sentence to classify
    sequence_to_classify = st.text_input("Enter the sentence to classify:")

    # Candidate labels input with help text
    st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
    candidate_labels = st.text_input("Candidate Labels:")

    # Confidence threshold slider
    confidence_threshold = st.slider("Confidence Threshold:", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

    # Nothing more to do until the user asks for a classification.
    if not st.button("Classify"):
        return

    sentence = sequence_to_classify.strip()
    labels = _parse_candidate_labels(candidate_labels)

    # Give explicit feedback instead of silently ignoring the click, and never
    # send a blank sentence or empty labels to the model.
    if not sentence or not labels:
        st.warning("Please enter a sentence and at least one candidate label.")
        return

    # Perform classification
    classification_result = classifier(sentence, labels)

    # Display classification results
    st.subheader("Classification Results:")
    for label, score in zip(classification_result["labels"], classification_result["scores"]):
        if score >= confidence_threshold:
            st.write(f"- {label}: {score}")
        else:
            st.write(f"- {label}: Below threshold ({score})")
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()